support pipeline partition with shared initializer (#4321)

* support bert partition with shared initializer

* address feedback

* address feedback

* address feedback

* add more test

* remove bert-tiny model

* address feedback

* address function comment

* move CreateNodeArg to graph_utils

* rename function name

* rename function name

* fix windows build

* fix windows type conversion warning

* add function comment
This commit is contained in:
Xueyun Zhu 2020-07-14 17:21:40 -07:00 committed by GitHub
parent 1ebe598286
commit 7d96960ec8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 739 additions and 306 deletions

View file

@ -140,6 +140,25 @@ class Node {
return common::Status::OK();
}
/**
Helper to iterate through the container returned by #MutableInputDefs() or #MutableOutputDefs() and call the provided function.
@param node_args Collection of NodeArgs returned by #MutableInputDefs() or #MutableOutputDefs()
@param func Function to call for each valid NodeArg in the node_args. The function is called with the NodeArg
and the index number in the container.
@returns common::Status with success or error information.
@remarks Returns immediately on error.
*/
static common::Status ForEachMutableWithIndex(std::vector<NodeArg*>& node_args,
std::function<common::Status(NodeArg& arg, size_t index)> func) {
for (size_t index = 0; index < node_args.size(); ++index) {
auto arg = node_args[index];
if (!arg->Exists())
continue;
ORT_RETURN_IF_ERROR(func(*arg, index));
}
return common::Status::OK();
}
/** Gets the count of arguments for each of the Node's explicit inputs. */
const std::vector<int>& InputArgCount() const noexcept { return definitions_.input_arg_count; }
@ -507,6 +526,9 @@ class Graph {
/** Remove the initializer tensor with the provided name from the Graph. */
void RemoveInitializedTensor(const std::string& tensor_name);
/** Check if a given name is an initializer tensor's name in this graph. */
bool IsInitializedTensor(const std::string& name) const;
/** Replaces the initializer tensor with the same name as the given initializer tensor.
The replacement initializer tensor must have the same type and shape as the existing initializer tensor.
@ -635,15 +657,14 @@ class Graph {
if (iter != node_args_.end()) {
return *(iter->second);
}
auto result = node_args_.insert(std::make_pair(name, onnxruntime::make_unique<NodeArg>(name, p_arg_type)));
return *(result.first->second);
}
/** Generate a unique name.in this Graph for a NodeArg */
/** Generate a unique name in this Graph for a NodeArg */
std::string GenerateNodeArgName(const std::string& base_name);
/** Generate a unique name.in this Graph for a Node */
/** Generate a unique name in this Graph for a Node */
std::string GenerateNodeName(const std::string& base_name);
/** Add a Node to this Graph.

View file

@ -253,6 +253,21 @@ Path& Path::Append(const Path& other) {
return *this;
}
Path& Path::Concat(const PathString& string) {
components_.back() += string;
return *this;
}
Path& Path::ConcatIndex(const int index) {
#ifdef _WIN32
auto index_str = std::to_wstring(index);
#else
auto index_str = std::to_string(index);
#endif
components_.back() += index_str;
return *this;
}
Status RelativePath(const Path& src, const Path& dst, Path& rel) {
ORT_RETURN_IF_NOT(
src.GetRootPathString() == dst.GetRootPathString(),

View file

@ -61,6 +61,19 @@ class Path {
* The algorithm should model that of std::filesystem::path::append().
*/
Path& Append(const Path& other);
/**
* Concatenates the current path and the argument string.
* Unlike with Append() or operator/=, additional directory separators are never introduced.
*/
Path& Concat(const PathString& string);
/**
* Concatenates an index by the end of current path.
* Similar to Concat() except the argument is an index.
*/
Path& ConcatIndex(const int index);
/** Equivalent to this->Append(other). */
Path& operator/=(const Path& other) {
return Append(other);

View file

@ -2331,6 +2331,10 @@ static void RemoveRepeatedFieldEntry(T& repeated_field, const TIter& entry_to_re
}
}
bool Graph::IsInitializedTensor(const std::string& name) const {
return name_to_initial_tensor_.count(name) > 0;
}
void Graph::RemoveInitializedTensor(const std::string& tensor_name) {
bool found = false;
auto iter = name_to_initial_tensor_.find(tensor_name);

View file

@ -181,7 +181,7 @@ static void RemoveGraphEdges(Graph& graph, const std::vector<GraphEdge>& edges)
}
/** Given a graph, a list of edges, and a NodeArg name, checks if each of the edges provides an implicit input
to a subgraph. If so, it checks if there is no clash of the given NodeArg name in each of the subgraphs.
to a subgraph. If so, it checks if there is no clash of the given NodeArg name in each of the subgraphs.
This is important when removing a node with this NodeArg as input. */
static bool CanUpdateImplicitInputNameInSubgraphs(const Graph& graph,
const std::vector<GraphEdge>& output_edges,
@ -321,7 +321,7 @@ const ONNX_NAMESPACE::AttributeProto* GetNodeAttribute(const Node& node, const s
return iter == attrs.end() ? nullptr : &iter->second;
}
/** Checks for nodes with >= 1 outputs, if only one of the outputs is input to downstream Operators.
/** Checks for nodes with >= 1 outputs, if only one of the outputs is input to downstream Operators.
Returns the name of the single used output in output_name. */
static bool IsOnlyOneOutputUsed(const Graph& graph, const Node& node, const std::string*& output_name) {
const int unassigned = -1;
@ -807,5 +807,10 @@ bool RemoveNodesWithOneOutputBottomUp(Graph& graph, const Node& start_node) {
}
return true;
}
NodeArg& CreateNodeArg(Graph& graph, const NodeArg& base_arg) {
return graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(base_arg.Name()), base_arg.TypeAsProto());
}
} // namespace graph_utils
} // namespace onnxruntime

View file

@ -34,27 +34,27 @@ bool IsOutputUsed(const Node& node, int index);
/** Returns true if the graph has the given input.*/
bool IsGraphInput(const Graph& graph, const NodeArg* input);
/** returns true if 'name' is an initializer in 'graph', or an ancestor graph if check_outer_scope is true.
/** returns true if 'name' is an initializer in 'graph', or an ancestor graph if check_outer_scope is true.
@param check_outer_scope If true and 'graph' is a subgraph, check ancestor graph/s for 'name' if not found in 'graph'.
*/
bool IsInitializer(const Graph& graph, const std::string& name, bool check_outer_scope);
/** returns true if 'name' is an initializer, and is constant and cannot be overridden at runtime.
/** returns true if 'name' is an initializer, and is constant and cannot be overridden at runtime.
@param check_outer_scope If true and 'graph' is a subgraph, check ancestor graph/s for 'name' if not found in 'graph'.
*/
bool IsConstantInitializer(const Graph& graph, const std::string& name, bool check_outer_scope = true);
/** returns the initializer's TensorProto if 'name' is an initializer, is constant and
/** returns the initializer's TensorProto if 'name' is an initializer, is constant and
cannot be overridden at runtime. If the initializer is not found or is not constant, a nullptr is returned.
@param check_outer_scope If true and the graph is a subgraph, check ancestor graph/s for 'name' if not found in 'graph'.
*/
const ONNX_NAMESPACE::TensorProto* GetConstantInitializer(const Graph& graph, const std::string& name,
bool check_outer_scope = true);
/** Add a new initializer to 'graph'.
/** Add a new initializer to 'graph'.
Checks that new_initializer does not already exist in 'graph' before adding it.
@returns The NodeArg for the new initializer.
@remarks No matching graph input is created, so the initializer will be constant.
@returns The NodeArg for the new initializer.
@remarks No matching graph input is created, so the initializer will be constant.
*/
NodeArg& AddInitializer(Graph& graph, const ONNX_NAMESPACE::TensorProto& new_initializer);
@ -127,7 +127,7 @@ bool RemoveNode(Graph& graph, Node& node);
/** Tests if we can remove a node and replace its output with an initializer.
Conditions:
- Only one of the node's outputs is used by downstream operators or as a graph output
- Only one of the node's outputs is used by downstream operators or as a graph output
- multiple edges for the single used output are allowed
- If the node produces a graph output the initializer_name must be the same as the node's output name
- otherwise the required graph output will not be produced
@ -140,20 +140,20 @@ bool CanReplaceNodeWithInitializer(const Graph& graph, const Node& node, const s
See CanReplaceNodeWithInitializer for the conditions that must be satisfied in order to remove the node.*/
bool ReplaceNodeWithInitializer(Graph& graph, Node& node, NodeArg& replacement);
/** Removes all output edges from the given Node of the Graph.
/** Removes all output edges from the given Node of the Graph.
This should probably be elevated to the Graph API eventually. */
size_t RemoveNodeOutputEdges(Graph& graph, Node& node);
/** Replaces the input to nodes that are downstream from 'node', which was being provided by an output of 'node',
/** Replaces the input to nodes that are downstream from 'node', which was being provided by an output of 'node',
with an output from a different node. Moves the output edges from 'node' for 'output_idx' to the replacement node.
@param replacement The node providing the replacement output.
@param replacement_output_idx The index of the output from 'replacement' to use.
@param replacement_output_idx The index of the output from 'replacement' to use.
e.g. Node A produces outputs A1 and A2.
Node B consumes A2 (edge between A and B for A2) and produces B1.
e.g. Node A produces outputs A1 and A2.
Node B consumes A2 (edge between A and B for A2) and produces B1.
Node C consumes B1 (edge between B and C for B1).
If Node B was determined to not be needed, you would call ReplaceDownstreamNodeInput(graph, B, 0, A, 1)
If Node B was determined to not be needed, you would call ReplaceDownstreamNodeInput(graph, B, 0, A, 1)
to replace B1 (output index 0 for node B) with A2 (output index 1 for node A) as input to the downstream node C.
The edge that existed between B and C for B1 will be removed, and replaced with an edge between A and C for A2.
*/
@ -161,29 +161,29 @@ void ReplaceDownstreamNodeInput(Graph& graph, Node& node, int output_idx, Node&
/** Replace the input to a node with a NodeArg.
@remarks The replacement only updates the node's input definition and does not create any edges,
as typically this function is used to replace an input with an initializer or graph input
as typically this function is used to replace an input with an initializer or graph input
(there is no edge between an initializer or graph input and a Node).
*/
void ReplaceNodeInput(Node& target, int target_input_idx, NodeArg& new_input);
/** Add an input to a node with a NodeArg for an initializer or graph input.
@remarks target_input_idx must be the next input slot.
e.g. if a Node has 2 inputs, AddNodeInput can only add input 3 and not 4.
There is no edge between an initializer or graph input and a Node, so the replacement only updates the
@remarks target_input_idx must be the next input slot.
e.g. if a Node has 2 inputs, AddNodeInput can only add input 3 and not 4.
There is no edge between an initializer or graph input and a Node, so the replacement only updates the
node's input definition and does not create any new edges.
*/
void AddNodeInput(Node& target, int target_input_idx, NodeArg& new_input);
/** Finalize the fusion of second_node into first_node.
/** Finalize the fusion of second_node into first_node.
The output definitions and edges from the second_node are moved to first_node. second_node is deleted.
e.g. Conv + Add fusion fuses the 'Add' into the Conv.
*/
void FinalizeNodeFusion(Graph& graph, Node& first_node, Node& second_node);
/** Finalize the fusion of two or more nodes which are being replaced with a single node.
/** Finalize the fusion of two or more nodes which are being replaced with a single node.
The first and last entries in 'nodes' are assumed to be the first and last nodes in a chain of nodes being fused.
Conceptually multiple nodes are being combined into one, and post-fusion will produce output/s with the same names
Conceptually multiple nodes are being combined into one, and post-fusion will produce output/s with the same names
as the last node in 'nodes', and be connected to the same downstream nodes.
The input edges to the first node in 'nodes' will be moved to replacement_node. No other input edges are moved.
@ -229,7 +229,7 @@ struct EdgeEndToMatch {
@param edges_to_match has information of a sequence of adjacent edges in the path to be matched one by one.
@param result stores edges that are found.
@returns false when one edge has multiple candidates, or not all edges are found.
@remarks matching an EdgeEndToMatch might get multiple candidates in output edges.
@remarks matching an EdgeEndToMatch might get multiple candidates in output edges.
When such case is encountered, this function will return false. This is by design to reduce complexity.
Here is an example graph:
Add
@ -237,7 +237,7 @@ struct EdgeEndToMatch {
Mul Mul
\ /
Sub
For example, you want to match path from top to bottom: Add-->Mul-->Sub.
For example, you want to match path from top to bottom: Add-->Mul-->Sub.
When matching the first edge Add-->Mul, the algorithm found two matches.
Then it returns false, and output a warning log entry.
@ -252,9 +252,15 @@ bool FindPath(Graph& graph, const Node& node, bool is_input_edge, const std::vec
/**
* Remove nodes with only one output edge using bottom-up bfs traversal.
* @param node: The node to start with.
* @param node: The node to start with.
* @returns true if there is one or more node(s) removed by this function. Otherwise return false.
*/
bool RemoveNodesWithOneOutputBottomUp(Graph& graph, const Node& node);
/** Creates a mutable NodeArg owned by the graph with mirrored base_arg's TypeProto and name
* @param base_arg The NodeArg the newly created NodeArg is mirrored based off.
* @returns NodeArg reference that contains the same TypeProto info as base_arg with generated different names.
*/
NodeArg& CreateNodeArg(Graph& graph, const NodeArg& base_arg);
} // namespace graph_utils
} // namespace onnxruntime

View file

@ -222,5 +222,34 @@ TEST(PathTest, RelativePathFailure) {
#endif
}
TEST(PathTest, Concat) {
auto check_concat =
[](const std::string& a, const std::string& b, const std::string& expected_a) {
Path p_a{}, p_expected_a{};
ASSERT_STATUS_OK(Path::Parse(ToPathString(a), p_a));
ASSERT_STATUS_OK(Path::Parse(ToPathString(expected_a), p_expected_a));
EXPECT_EQ(p_a.Concat(ToPathString(b)).ToPathString(), p_expected_a.ToPathString());
};
check_concat("/a/b", "c", "/a/bc");
check_concat("a/b", "cd", "a/bcd");
}
TEST(PathTest, ConcatIndex) {
auto check_concat_index =
[](const std::string& a, const int i, const std::string& expected_a) {
Path p_a{}, p_expected_a{};
ASSERT_STATUS_OK(Path::Parse(ToPathString(a), p_a));
ASSERT_STATUS_OK(Path::Parse(ToPathString(expected_a), p_expected_a));
EXPECT_EQ(p_a.ConcatIndex(i).ToPathString(), p_expected_a.ToPathString());
};
check_concat_index("/a/b", 0, "/a/b0");
check_concat_index("a/b", 123, "a/b123");
check_concat_index("a/b", -1, "a/b-1");
}
} // namespace test
} // namespace onnxruntime

View file

@ -4,7 +4,10 @@
#include "orttraining/core/graph/pipeline_transformer.h"
#include <queue>
#include "core/graph/graph_utils.h"
using namespace onnxruntime::common;
using namespace onnxruntime::graph_utils;
namespace onnxruntime {
namespace training {
@ -78,17 +81,15 @@ std::vector<NodeArg*> FindBackwardLeafNodes(Graph& graph) {
// The newly created boolean scalar may be appended to signal_args. If
// signal_args is empty, the source of signal_args[i] would be tensor_args[i].
void ConvertTensorToBoolSignal(
Graph& graph,
const std::vector<NodeArg*>& tensor_args,
std::vector<NodeArg*>& signal_args) {
for (auto tensor_arg: tensor_args) {
Graph& graph,
const std::vector<NodeArg*>& tensor_args,
std::vector<NodeArg*>& signal_args) {
for (auto tensor_arg : tensor_args) {
// Declare the scalar signal this "tensor_arg" will be converted into.
auto signal_arg = &CreateTypedNodeArg(
graph,
ONNX_NAMESPACE::TensorProto_DataType_BOOL,
"signal_" + tensor_arg->Name()
);
graph,
ONNX_NAMESPACE::TensorProto_DataType_BOOL,
"signal_" + tensor_arg->Name());
// Add the new scalar to user-specified vector.
signal_args.push_back(signal_arg);
@ -98,37 +99,28 @@ void ConvertTensorToBoolSignal(
std::vector<NodeArg*> input_args{tensor_arg};
std::vector<NodeArg*> output_args{signal_arg};
graph.AddNode(
name,
"Group",
"",
input_args,
output_args,
nullptr,
kMSDomain);
name,
"Group",
"",
input_args,
output_args,
nullptr,
kMSDomain);
}
}
// Return mirror variables for node_arg with a different name.
NodeArg& CreateNodeArg(Graph& graph, const NodeArg* base_arg) {
const auto& new_name = graph.GenerateNodeArgName(base_arg->Name());
ONNX_NAMESPACE::TypeProto type_proto(*(base_arg->TypeAsProto()));
if (graph.GetNodeArg(new_name) != nullptr) {
ORT_THROW("Node with name ", new_name, " already exists.");
}
return graph.GetOrCreateNodeArg(new_name, &type_proto);
}
// Return mirror variables for node_args.
// The i-th output element mirrors node_args[i] but with a different name.
std::vector<NodeArg*> CreateMirrorNodeArgs(
Graph& graph,
const std::vector<NodeArg*>& node_args) {
Graph& graph,
const std::vector<NodeArg*>& node_args) {
// Declare output.
std::vector<NodeArg*> new_node_args;
for (auto& node_arg: node_args) {
for (auto& node_arg : node_args) {
// new_node_arg is a mirror variable of node_arg. They have the same type.
auto new_node_arg = &CreateNodeArg(graph, node_arg);
assert(node_arg);
auto new_node_arg = &CreateNodeArg(graph, *node_arg);
new_node_args.push_back(new_node_arg);
}
@ -138,33 +130,33 @@ std::vector<NodeArg*> CreateMirrorNodeArgs(
// Create a node with input schema [event, input1, input2, ..., inputN] and
// output schema [input1, input2, ..., inputN]
Node& CreateBottleneckNode(Graph& graph,
const std::string& op_type,
const std::string& op_name,
const std::string& description,
NodeArg* event,
std::vector<NodeArg*> input_node_args,
std::vector<NodeArg*> output_node_args) {
const std::string& op_type,
const std::string& op_name,
const std::string& description,
NodeArg* event,
std::vector<NodeArg*> input_node_args,
std::vector<NodeArg*> output_node_args) {
const auto name = graph.GenerateNodeName(op_name);
if (event) {
input_node_args.insert(input_node_args.begin(), event);
}
return graph.AddNode(
name,
op_type,
description,
input_node_args,
output_node_args,
nullptr /* assume all bottleneck node have no attributes */,
kMSDomain);
name,
op_type,
description,
input_node_args,
output_node_args,
nullptr /* assume all bottleneck node have no attributes */,
kMSDomain);
}
Node* AddBackwardRecord(Graph& graph,
Node* backward_send,
std::vector<std::string>& new_input_names,
std::vector<std::string>& new_output_names,
std::string &event_id_tensor_name,
std::string &output_tensor_name) {
std::string& event_id_tensor_name,
std::string& output_tensor_name) {
std::vector<NodeArg*> input_args;
AddNewNodeArg(graph, "backward_recorded_event_id", ONNX_NAMESPACE::TensorProto_DataType_INT64,
input_args, new_input_names);
@ -194,13 +186,16 @@ Node* AddBackwardRecord(Graph& graph,
// Optimizer will be added after applying pipeline transformer. To support partial graph evaluation,
// the added Record backward op will have its first passthrough input as output.
ORT_ENFORCE(input_args.size() >= 2, "RecordEvent backward op at least have two inputs.");
auto& new_output = CreateNodeArg(graph, input_args[1]); // the first input is signal, not passing through
// RecordEvent doesn't have optional input, so it cannot be nullptr.
assert(input_args[1]);
auto& new_output = CreateNodeArg(graph, *input_args[1]); // the first input is signal, not passing through
output_args.push_back(&new_output);
new_output_names.push_back(new_output.Name());
Node* record_node = &CreateBottleneckNode(
graph, "RecordEvent", "backward_record", "Backward pass", nullptr,
input_args, output_args);
graph, "RecordEvent", "backward_record", "Backward pass", nullptr,
input_args, output_args);
// First input argument is the recorded event ID tensor.
event_id_tensor_name = input_args.front()->Name();
@ -212,21 +207,19 @@ Node* AddBackwardRecord(Graph& graph,
}
Node* AddForwardWait(Graph& graph,
Node* /* forward_recv */,
std::vector<std::string>& new_input_names,
std::string& forward_waited_event_name,
std::string& output_tensor_name) {
Node* /* forward_recv */,
std::vector<std::string>& new_input_names,
std::string& forward_waited_event_name,
std::string& output_tensor_name) {
// Append old_input to input_args and return its pass-through value. Note that
// input_args and output_args are Wait's inputs and outputs, respectively.
auto update_wait_input_output = [&](NodeArg* old_input,
std::vector<NodeArg*>& input_args,
std::vector<NodeArg*>& output_args) -> NodeArg& {
assert(old_input);
input_args.push_back(old_input);
const auto& new_name = graph.GenerateNodeArgName(old_input->Name());
ONNX_NAMESPACE::TypeProto type_proto(*(old_input->TypeAsProto()));
auto& wait_output = graph.GetOrCreateNodeArg(new_name, &type_proto);
auto& wait_output = CreateNodeArg(graph, *old_input);
output_args.push_back(&wait_output);
return wait_output;
@ -257,8 +250,8 @@ Node* AddForwardWait(Graph& graph,
}
Node* wait_node = &CreateBottleneckNode(
graph, "WaitEvent", "backward_record", "", nullptr,
input_args, output_args);
graph, "WaitEvent", "backward_record", "", nullptr,
input_args, output_args);
forward_waited_event_name = input_args.front()->Name();
output_tensor_name = output_args.front()->Name();
@ -276,13 +269,15 @@ Status AddOrSkipForwardRecordBackwardWait(Graph& graph,
std::string& backward_waited_event_name,
std::string& forward_output_name,
std::string& backward_output_name) {
if (!forward_send != !backward_recv){
ORT_THROW("Graph requires either having both send forward node "
"and recv backword node, or none of them. Currently the graph "
"has send forward: ", forward_send, " and recv backward: ", backward_recv);
if (!forward_send != !backward_recv) {
ORT_THROW(
"Graph requires either having both send forward node "
"and recv backward node, or none of them. Currently the graph "
"has send forward: ",
forward_send, " and recv backward: ", backward_recv);
}
if (!forward_send && !backward_recv){
if (!forward_send && !backward_recv) {
// Last partition doesn't have send forwrad and recv backward. No insert
// needed.
return Status::OK();
@ -302,14 +297,16 @@ Status AddOrSkipForwardRecordBackwardWait(Graph& graph,
// Add send forward op's output as record op's input and output
for (auto& output : forward_send->MutableOutputDefs()) {
auto& new_output = CreateNodeArg(graph, output);
// send doesn't have optional output, so the node cannot be nullptr.
assert(output);
auto& new_output = CreateNodeArg(graph, *output);
output_args.push_back(&new_output);
input_args.push_back(output);
}
record_node = &CreateBottleneckNode(
graph, "RecordEvent", "forward_record", "", nullptr,
input_args, output_args);
graph, "RecordEvent", "forward_record", "", nullptr,
input_args, output_args);
forward_recorded_event_name = record_node->InputDefs()[0]->Name();
forward_output_name = record_node->OutputDefs()[0]->Name();
@ -327,13 +324,16 @@ Status AddOrSkipForwardRecordBackwardWait(Graph& graph,
std::end(record_node->MutableOutputDefs()));
auto& input = backward_recv->MutableInputDefs()[0];
auto& new_output = CreateNodeArg(graph, input);
// recv node doesn't have optional input, so the node cannot be nullptr.
assert(input);
auto& new_output = CreateNodeArg(graph, *input);
output_args.push_back(&new_output);
input = &new_output;
wait_node = &CreateBottleneckNode(
graph, "WaitEvent", "backward_wait", "Backward pass", nullptr,
input_args, output_args);
graph, "WaitEvent", "backward_wait", "Backward pass", nullptr,
input_args, output_args);
backward_waited_event_name = wait_node->InputDefs()[0]->Name();
backward_output_name = wait_node->OutputDefs()[0]->Name();
@ -352,8 +352,8 @@ void ReplaceNodeArgs(std::vector<Node*>& nodes,
ORT_ENFORCE(node_args[i]->Name() != new_node_args[i]->Name());
ORT_ENFORCE(node_args[i]->Type() == new_node_args[i]->Type());
for (auto& node: nodes) {
for (auto& node_arg: node->MutableInputDefs()) {
for (auto& node : nodes) {
for (auto& node_arg : node->MutableInputDefs()) {
// Only replace when node's input name matches node_args[i].
if (node_arg->Name().compare(node_args[i]->Name()) != 0) {
continue;
@ -371,11 +371,11 @@ void ReplaceNodeArgs(std::vector<Node*>& nodes,
}
std::string AddEventBeforeNode(
Graph& graph,
Node* node,
const std::string& event_op_type,
const std::string& event_op_name,
const std::string& event_id_name) {
Graph& graph,
Node* node,
const std::string& event_op_type,
const std::string& event_op_name,
const std::string& event_id_name) {
if (!node) {
// No event operator is be inserted, so we don't have its input event name.
return "";
@ -410,11 +410,11 @@ std::string AddEventBeforeNode(
}
std::string AddEventAfterNode(
Graph& graph,
Node* node,
const std::string& event_op_type,
const std::string& event_op_name,
const std::string& event_id_name) {
Graph& graph,
Node* node,
const std::string& event_op_type,
const std::string& event_op_name,
const std::string& event_id_name) {
if (!node) {
// No event operator is be inserted, so we don't have its input event name.
return "";
@ -431,7 +431,7 @@ std::string AddEventAfterNode(
for (size_t i = 0; i < node_args.size(); ++i) {
// Find consumer of "node"'s i-th output.
std::vector<Node*> consumer_nodes = graph.GetMutableConsumerNodes(
node_args[i]->Name());
node_args[i]->Name());
// Replace node_args[i] with new_node_args[i] in nodes.
ReplaceNodeArgs(consumer_nodes, {node_args[i]}, {new_node_args[i]});
}
@ -457,9 +457,9 @@ Status AddForwardWaitAfterRecv(
std::vector<std::string>& new_input_names,
std::string& event_name) {
event_name = AddEventAfterNode(
graph, comm_node,
"WaitEvent", "forward_wait_after_recv",
"forward_wait_after_recv_event_id");
graph, comm_node,
"WaitEvent", "forward_wait_after_recv",
"forward_wait_after_recv_event_id");
if (event_name.empty()) {
return Status::OK();
} else {
@ -474,9 +474,9 @@ Status AddForwardRecordBeforeSend(
std::vector<std::string>& new_input_names,
std::string& event_name) {
event_name = AddEventBeforeNode(
graph, comm_node,
"RecordEvent", "forward_record_before_send",
"forward_record_before_send_event_id");
graph, comm_node,
"RecordEvent", "forward_record_before_send",
"forward_record_before_send_event_id");
if (event_name.empty()) {
return Status::OK();
} else {
@ -491,9 +491,9 @@ Status AddBackwardWaitAfterRecv(
std::vector<std::string>& new_input_names,
std::string& event_name) {
event_name = AddEventAfterNode(
graph, comm_node,
"WaitEvent", "backward_wait_after_recv",
"backward_wait_after_recv_event_id");
graph, comm_node,
"WaitEvent", "backward_wait_after_recv",
"backward_wait_after_recv_event_id");
if (event_name.empty()) {
return Status::OK();
} else {
@ -508,9 +508,9 @@ Status AddBackwardRecordBeforeSend(
std::vector<std::string>& new_input_names,
std::string& event_name) {
event_name = AddEventBeforeNode(
graph, comm_node,
"RecordEvent", "backward_record_before_send",
"backward_record_before_send_event_id");
graph, comm_node,
"RecordEvent", "backward_record_before_send",
"backward_record_before_send_event_id");
if (event_name.empty()) {
return Status::OK();
} else {
@ -605,20 +605,20 @@ Status SetInputsOutputsAndResolve(Graph& graph,
// Record-2: Tell others that backward pass is done.
// Record-3: Tell others that backward result has been passed to another stage.
Status TransformGraphForPipeline(
Graph& graph,
const std::unordered_set<std::string>& weights_to_train,
std::string& forward_waited_event_name,
std::string& forward_recorded_event_name,
std::string& backward_waited_event_name,
std::string& backward_recorded_event_name,
std::string& forward_wait_output_name,
std::string& forward_record_output_name,
std::string& backward_wait_output_name,
std::string& backward_record_output_name,
std::string& forward_waited_event_after_recv_name,
std::string& forward_recorded_event_before_send_name,
std::string& backward_waited_event_after_recv_name,
std::string& backward_recorded_event_before_send_name) {
Graph& graph,
const std::unordered_set<std::string>& weights_to_train,
std::string& forward_waited_event_name,
std::string& forward_recorded_event_name,
std::string& backward_waited_event_name,
std::string& backward_recorded_event_name,
std::string& forward_wait_output_name,
std::string& forward_record_output_name,
std::string& backward_wait_output_name,
std::string& backward_record_output_name,
std::string& forward_waited_event_after_recv_name,
std::string& forward_recorded_event_before_send_name,
std::string& backward_waited_event_after_recv_name,
std::string& backward_recorded_event_before_send_name) {
// Declare nodes according to their topological order.
Node* forward_wait{nullptr};
Node* forward_send{nullptr};
@ -650,27 +650,27 @@ Status TransformGraphForPipeline(
std::vector<std::string> new_output_names;
backward_record = AddBackwardRecord(
graph,
backward_send,
new_input_names,
new_output_names,
backward_recorded_event_name,
backward_record_output_name);
graph,
backward_send,
new_input_names,
new_output_names,
backward_recorded_event_name,
backward_record_output_name);
forward_wait = AddForwardWait(
graph,
forward_recv,
new_input_names,
forward_waited_event_name,
forward_wait_output_name);
graph,
forward_recv,
new_input_names,
forward_waited_event_name,
forward_wait_output_name);
ORT_RETURN_IF_ERROR(AddOrSkipForwardRecordBackwardWait(
graph,
forward_send,
backward_recv,
new_input_names,
forward_recorded_event_name,
backward_waited_event_name,
forward_record_output_name,
backward_wait_output_name));
graph,
forward_send,
backward_recv,
new_input_names,
forward_recorded_event_name,
backward_waited_event_name,
forward_record_output_name,
backward_wait_output_name));
// Different stages have different patterns of Send & Recv.
// For different patterns, we add different WaitEvent and Record.
@ -694,11 +694,11 @@ Status TransformGraphForPipeline(
// One and only one of is_first_stage, is_middle_stage, and is_last_stage can be true.
const unsigned int stage_flag_sum = is_first_stage + is_middle_stage + is_last_stage;
ORT_RETURN_IF_NOT(stage_flag_sum == 1u,
"The processed graph should be classified into a stage, "
"but we see more than one true's in the following statements. ",
"Is first stage? ", is_first_stage, ". ",
"Is middle stage? ", is_middle_stage, ". ",
"Is last stage? ", is_last_stage, ".");
"The processed graph should be classified into a stage, "
"but we see more than one true's in the following statements. ",
"Is first stage? ", is_first_stage, ". ",
"Is middle stage? ", is_middle_stage, ". ",
"Is last stage? ", is_last_stage, ".");
// Now, we add Wait's in parentheses shown below.
// 1. First stage:
@ -713,19 +713,17 @@ Status TransformGraphForPipeline(
if (is_first_stage) {
// If first stage, insert after forward WaitEvent.
ORT_RETURN_IF_ERROR(AddForwardWaitAfterRecv(
graph,
forward_wait,
new_input_names,
forward_waited_event_after_recv_name
));
graph,
forward_wait,
new_input_names,
forward_waited_event_after_recv_name));
} else if (is_middle_stage || is_last_stage) {
// If middle stage or last stage, insert after forward Recv.
ORT_RETURN_IF_ERROR(AddForwardWaitAfterRecv(
graph,
forward_recv,
new_input_names,
forward_waited_event_after_recv_name
));
graph,
forward_recv,
new_input_names,
forward_waited_event_after_recv_name));
}
// Now, we add Record's in parentheses shown below.
@ -740,11 +738,10 @@ Status TransformGraphForPipeline(
// ----------------------> BW -> Record -> Send -> Record
if (is_first_stage || is_middle_stage) {
ORT_RETURN_IF_ERROR(AddForwardRecordBeforeSend(
graph,
forward_send,
new_input_names,
forward_recorded_event_before_send_name
));
graph,
forward_send,
new_input_names,
forward_recorded_event_before_send_name));
}
// Now, we add Wait's in parentheses shown below.
@ -759,11 +756,10 @@ Status TransformGraphForPipeline(
// ----------------------> BW -> Record -> Send -> Record
if (is_first_stage || is_middle_stage) {
ORT_RETURN_IF_ERROR(AddBackwardWaitAfterRecv(
graph,
backward_recv,
new_input_names,
backward_waited_event_after_recv_name
));
graph,
backward_recv,
new_input_names,
backward_waited_event_after_recv_name));
}
// Now, we add Record's in parentheses shown below.
@ -778,18 +774,16 @@ Status TransformGraphForPipeline(
// ----------------------> BW -> (Record) -> Send -> Record
if (is_first_stage) {
ORT_RETURN_IF_ERROR(AddBackwardRecordBeforeSend(
graph,
backward_record,
new_input_names,
backward_recorded_event_before_send_name
));
graph,
backward_record,
new_input_names,
backward_recorded_event_before_send_name));
} else if (is_middle_stage || is_last_stage) {
ORT_RETURN_IF_ERROR(AddBackwardRecordBeforeSend(
graph,
backward_send,
new_input_names,
backward_recorded_event_before_send_name
));
graph,
backward_send,
new_input_names,
backward_recorded_event_before_send_name));
}
ORT_RETURN_IF_ERROR(SetInputsOutputsAndResolve(graph, weights_to_train, new_input_names, new_output_names));
@ -801,11 +795,11 @@ Status TransformGraphForPipeline(
// It also cerates an initializer to store its value.
template <typename T>
void AddNewScalarNodeArgAndInitializer(Graph& graph,
const std::string& op_name,
onnx::TensorProto_DataType type,
T data,
std::vector<NodeArg*>& new_node_args,
std::vector<std::string>& new_names) {
const std::string& op_name,
onnx::TensorProto_DataType type,
T data,
std::vector<NodeArg*>& new_node_args,
std::vector<std::string>& new_names) {
AddNewNodeArg(graph, op_name, type, new_node_args, new_names);
ONNX_NAMESPACE::TensorProto proto_data;
@ -825,6 +819,218 @@ void AddNewScalarNodeArgAndInitializer(Graph& graph,
graph.AddInitializedTensor(proto_data);
}
// Given a node, this function finds all its connected nodes (consumer nodes and producer nodes) and
// connected inputs and outputs in the given graph, then adds them to the containers passed in.
Status FindAllConnectedNodes(Graph& graph,
Node* node,
std::vector<Node*>& connected_nodes,
std::set<NodeArg*>& connected_inputs,
std::set<NodeArg*>& connected_outputs
) {
assert(node);
ORT_THROW_IF_ERROR(node->ForEachMutableWithIndex(
node->MutableInputDefs(),
[&](NodeArg& node_arg, size_t /*index*/) {
if (graph.IsInputsIncludingInitializers(&node_arg) || graph.IsInitializedTensor(node_arg.Name())) {
connected_inputs.insert(&node_arg);
} else {
Node* producer_node = graph.GetMutableProducerNode(node_arg.Name());
if (producer_node == nullptr) {
// got nullptr as producer node. This could be because the input is a constant op which will be optimized
// away. Print out this information and continue.
// TODO: re-visit the different cases to see if there are other situations aside from constant ops.
LOGS_DEFAULT(WARNING) << "Cannot find producer node for node_arg: " << node_arg.Name() << ". Skipping this node.";
} else {
connected_nodes.push_back(producer_node);
}
}
return Status::OK();
}));
ORT_THROW_IF_ERROR(node->ForEachMutableWithIndex(
node->MutableOutputDefs(),
[&](NodeArg& node_arg, size_t /*index*/) {
if (!graph.IsOutput(&node_arg)) {
std::vector<Node*> consumer_nodes = graph.GetMutableConsumerNodes(node_arg.Name());
connected_nodes.insert(std::end(connected_nodes), consumer_nodes.begin(), consumer_nodes.end());
} else {
connected_outputs.insert(&node_arg);
}
return Status::OK();
}));
return Status::OK();
}
// PipelineStageNodeGroup groups nodes that share the same input initializer and belong to the same stage.
// It is used to distinguish other nodes that share the same input initializer but belong to
// other pipeline partitions after split.
struct PipelineStageNodeGroup {
const size_t stage_id;
// Vector of nodes that have the same initializer input and belong to the same stage. Noted that
// the consumer nodes of a particular initializer can be more than one, so we need a vector to store those
// nodes.
std::vector<Node*> nodes;
PipelineStageNodeGroup(const size_t stage, std::vector<Node*>& node_group) : stage_id(stage), nodes(std::move(node_group)){};
};
// This function passes through the given initializer across stages specified in node_groups[i].stage_id.
// This applies to the case when initializer is used in multiple stages, say stage a and stage b (a<b). We will
// keep the initializer in stage a and pass it down to b through the send nodes and recv nodes.
common::Status AddPassthroughInitializer(Graph& graph,
NodeArg* initializer,
const std::vector<PipelineStageNodeGroup>& node_groups,
const std::vector<Node*>& send_nodes,
const std::vector<Node*>& recv_nodes) {
assert(initializer);
ORT_ENFORCE(node_groups.size() >= 2, "Initializer ", initializer->Name(), " is not shared across stages.");
const size_t from_stage = node_groups.front().stage_id;
const size_t to_stage = node_groups.back().stage_id;
ORT_ENFORCE(from_stage < to_stage, "Pass through from_stage (", from_stage,
") is not less than the to_stage (", to_stage, ").");
auto dtype = initializer->TypeAsProto()->tensor_type().elem_type();
auto current_node_arg = initializer;
size_t node_group_index = 1;
for (auto i = from_stage; i < to_stage; ++i) {
// processing send node in cut i
auto& send_attributes = send_nodes[i]->GetMutableAttributes();
auto& send_element_types = send_attributes["element_types"];
send_element_types.add_ints(static_cast<int64_t>(dtype));
send_nodes[i]->MutableInputDefs().push_back(current_node_arg);
send_nodes[i]->MutableInputArgsCount().back()++;
// Create a new node_arg for the recv, as the new node_arg from recv node should possess a different id
// than the one in send
assert(current_node_arg);
current_node_arg = &CreateNodeArg(graph, *current_node_arg);
// process recv node in cut i
auto& recv_attributes = recv_nodes[i]->GetMutableAttributes();
auto& recv_element_types = recv_attributes["element_types"];
recv_element_types.add_ints(static_cast<int64_t>(dtype));
recv_nodes[i]->MutableOutputDefs().push_back(current_node_arg);
// update the consumer node's input if the node's group is not in the first partition
if (i > from_stage && node_groups[node_group_index].stage_id == (i + 1)) {
for (auto node : node_groups[node_group_index].nodes) {
for (auto& input_node : node->MutableInputDefs()) {
if (input_node == initializer) {
input_node = current_node_arg;
break;
}
}
}
node_group_index++;
}
}
ORT_ENFORCE(node_group_index == node_groups.size(), "Not all nodes are updated with new initializer.");
return Status::OK();
}
// Traverse the graph to find out all connected elements in the graph from start_node. The traverse treats the graph as an
// undirected graph.
void TraverseGraphWithConnectedElement(Graph& graph,
Node* start_node,
std::set<Node*>& visited_nodes,
std::set<NodeArg*>& visited_inputs,
std::set<NodeArg*>& visited_outputs) {
assert(start_node);
visited_nodes.clear();
visited_inputs.clear();
visited_outputs.clear();
std::queue<Node*> node_queue;
node_queue.push(start_node);
while (!node_queue.empty()) {
auto node = node_queue.front();
node_queue.pop();
if (visited_nodes.insert(node).second) {
std::vector<Node*> connected_nodes;
ORT_THROW_IF_ERROR(FindAllConnectedNodes(graph, node, connected_nodes, visited_inputs, visited_outputs));
for (auto n : connected_nodes) {
ORT_ENFORCE(n != nullptr, "Found nullptr in searching for connected nodes");
node_queue.push(n);
}
}
}
}
// If an initializer is shared across partitions, instead of creating a separate all_reduce op to
// sync with those tensors in selected partitions, we save only one copy of that initializer in
// the very first partition it appears, and pass that data down to all following partitions
// where this initializer is used.
common::Status HandleSharedInitializer(Graph& graph,
const std::vector<Node*>& send_nodes,
const std::vector<Node*>& recv_nodes) {
// Map a given initializer to all the partitions that its consumer nodes reside. The size of
// the mapped vector reflects how many partitions this initializer's consumer nodes distribute.
// If its size is greater than 1, it means this initializer is being used in more than one partition and
// we need to proceed those cases.
std::map<NodeArg*, std::vector<PipelineStageNodeGroup>> input_consumer_stage_map;
for (size_t stage = 0; stage <= send_nodes.size(); ++stage) {
std::set<Node*> visited_nodes;
std::set<NodeArg*> visited_inputs;
std::set<NodeArg*> visited_outputs;
// send_nodes[i] is the Send op in i-th stage's forward pass. recv_nodes[i] is the Recv in the (i+1)-th stage's
// forward pass. When not in last stage, traverse start from send node; otherwise start with the recv node as
// send node doesn't exist in last partition's forward pass.
Node* traverse_start_node = stage < send_nodes.size() ? send_nodes[stage] : recv_nodes.back();
TraverseGraphWithConnectedElement(graph,
traverse_start_node,
visited_nodes,
visited_inputs,
visited_outputs);
for (const auto input : visited_inputs) {
// If the node is an input instead of an initializer, continue
if (!graph.IsInitializedTensor(input->Name())){
continue;
}
// group all consumer nodes that shares the same input initializer in visited_consumer_nodes
std::vector<Node*> consumer_nodes = graph.GetMutableConsumerNodes(input->Name());
std::vector<Node*> visited_consumer_nodes;
for(auto consumer_node : consumer_nodes){
if (visited_nodes.count(consumer_node) != 0){
visited_consumer_nodes.push_back(consumer_node);
}
}
if (input_consumer_stage_map.count(input) == 0) {
input_consumer_stage_map[input] = std::vector<PipelineStageNodeGroup>{
PipelineStageNodeGroup(stage, visited_consumer_nodes)};
} else {
input_consumer_stage_map[input].push_back({stage, visited_consumer_nodes});
}
}
}
for (const auto& entry : input_consumer_stage_map) {
// If any initializer is shared, handle the logic of passing it from the first seen stage all
// the way to last seen stage.
if (entry.second.size() > 1) {
ORT_RETURN_IF_ERROR(AddPassthroughInitializer(graph,
entry.first, // initializer node_arg
entry.second, // initializer consumer node groups
send_nodes,
recv_nodes));
}
}
return Status::OK();
}
// split the graph into disconnected subgraph based on provided CutInfo
common::Status SplitGraph(Graph& graph,
std::vector<TrainingSession::TrainingConfiguration::CutInfo> split_edge_groups,
@ -869,30 +1075,30 @@ common::Status SplitGraph(Graph& graph,
auto cut_index_str = std::to_string(index);
// add input node_arg and initializer for send/recv
AddNewScalarNodeArgAndInitializer<bool>(graph,
"send_input_signal" + cut_index_str,
ONNX_NAMESPACE::TensorProto_DataType_BOOL,
true, /* initializer data */
send_input_args,
new_input_names);
"send_input_signal" + cut_index_str,
ONNX_NAMESPACE::TensorProto_DataType_BOOL,
true, /* initializer data */
send_input_args,
new_input_names);
AddNewScalarNodeArgAndInitializer<bool>(graph,
"recv_input_signal" + cut_index_str,
ONNX_NAMESPACE::TensorProto_DataType_BOOL,
true, /* initializer data */
recv_input_args,
new_input_names);
"recv_input_signal" + cut_index_str,
ONNX_NAMESPACE::TensorProto_DataType_BOOL,
true, /* initializer data */
recv_input_args,
new_input_names);
AddNewScalarNodeArgAndInitializer<size_t>(graph,
"send_dst_rank" + cut_index_str,
ONNX_NAMESPACE::TensorProto_DataType_INT64,
index + 1, /* initializer data */
send_input_args,
new_input_names);
"send_dst_rank" + cut_index_str,
ONNX_NAMESPACE::TensorProto_DataType_INT64,
index + 1, /* initializer data */
send_input_args,
new_input_names);
AddNewScalarNodeArgAndInitializer<size_t>(graph,
"recv_src_rank" + cut_index_str,
ONNX_NAMESPACE::TensorProto_DataType_INT64,
index, /* initializer data */
recv_input_args,
new_input_names);
"recv_src_rank" + cut_index_str,
ONNX_NAMESPACE::TensorProto_DataType_INT64,
index, /* initializer data */
recv_input_args,
new_input_names);
// add output node_arg for send/recv
AddNewNodeArg(graph, "send_output_signal" + cut_index_str, ONNX_NAMESPACE::TensorProto_DataType_BOOL,
send_output_args, new_output_names);
@ -948,6 +1154,7 @@ common::Status SplitGraph(Graph& graph,
if (exiting_updated_node_arg != updated_node_args.end()) {
updated_node_arg = exiting_updated_node_arg->second;
}
assert(updated_node_arg);
send_input_args.push_back(updated_node_arg);
@ -955,7 +1162,7 @@ common::Status SplitGraph(Graph& graph,
element_types.add_ints(static_cast<int64_t>(dtype));
auto& new_receive_output = CreateNodeArg(graph, updated_node_arg);
auto& new_receive_output = CreateNodeArg(graph, *updated_node_arg);
const auto old_shape = *(updated_node_arg->Shape());
new_receive_output.SetShape(old_shape);
recv_output_args.push_back(&new_receive_output);
@ -973,7 +1180,7 @@ common::Status SplitGraph(Graph& graph,
// deal with updating the consumer's input node_args
std::vector<Node*> consumer_nodes;
if (id.consumer_nodes.has_value()) {
for(auto& consumer_node_id : id.consumer_nodes.value()){
for (auto& consumer_node_id : id.consumer_nodes.value()) {
consumer_nodes.push_back(graph.GetMutableProducerNode(consumer_node_id));
}
} else {
@ -1019,68 +1226,17 @@ common::Status SplitGraph(Graph& graph,
return Status::OK();
}
Status FindAllConnectedNodes(Graph& graph,
const Node* node,
std::vector<const Node*>& connected_nodes,
std::set<const NodeArg*>& connected_inputs,
std::set<const NodeArg*>& connected_outputs) {
ORT_THROW_IF_ERROR(node->ForEachWithIndex(
node->InputDefs(),
[&](const NodeArg& node_arg, size_t /*index*/) {
if (graph.IsInputsIncludingInitializers(&node_arg)) {
connected_inputs.insert(&node_arg);
} else {
const Node* producer_node = graph.GetProducerNode(node_arg.Name());
if (producer_node == nullptr) {
// got nullptr as producer node. This could be because the input is a constant op which will be optimized
// away. Print out this information and continue.
LOGS_DEFAULT(WARNING) << "Cannot find producer node for node_arg: " << node_arg.Name() << ". Skipping this node.";
} else {
connected_nodes.push_back(producer_node);
}
}
return Status::OK();
}));
ORT_THROW_IF_ERROR(node->ForEachWithIndex(
node->OutputDefs(),
[&](const NodeArg& node_arg, size_t /*index*/) {
if (!graph.IsOutput(&node_arg)) {
std::vector<const Node*> consumer_nodes = graph.GetConsumerNodes(node_arg.Name());
connected_nodes.insert(std::end(connected_nodes), consumer_nodes.begin(), consumer_nodes.end());
} else {
connected_outputs.insert(&node_arg);
}
return Status::OK();
}));
return Status::OK();
}
// traverse the graph from start_node to get the set of nodes contains in this disconnected subgraph
common::Status GenerateSubgraph(Graph& graph, const Node* start_node) {
std::queue<const Node*> node_queue;
node_queue.push(start_node);
std::set<const Node*> visited_nodes;
std::set<const NodeArg*> visited_inputs;
std::set<const NodeArg*> visited_outputs;
common::Status GenerateSubgraph(Graph& graph, Node* start_node) {
assert(start_node);
std::set<Node*> visited_nodes;
std::set<NodeArg*> visited_inputs;
std::set<NodeArg*> visited_outputs;
// BFS graph traverse
while (!node_queue.empty()) {
auto node = node_queue.front();
node_queue.pop();
if (visited_nodes.count(node) == 0) {
visited_nodes.insert(node);
std::vector<const Node*> connected_nodes;
ORT_THROW_IF_ERROR(FindAllConnectedNodes(graph, node, connected_nodes, visited_inputs, visited_outputs));
TraverseGraphWithConnectedElement(graph, start_node,
visited_nodes, visited_inputs, visited_outputs);
for (auto n : connected_nodes) {
ORT_ENFORCE(n!=nullptr, "Found nullptr in searching for connected nodes");
node_queue.push(n);
}
}
}
std::set<NodeIndex> visited_node_index;
for (auto n : visited_nodes) {
visited_node_index.insert(n->Index());
@ -1090,8 +1246,8 @@ common::Status GenerateSubgraph(Graph& graph, const Node* start_node) {
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
// reverse iterate the nodes in tolopogical order, and delete those not visited
for (auto it = node_topology_list.rbegin(); it != node_topology_list.rend(); it++){
if (visited_node_index.count(*it)==0){
for (auto it = node_topology_list.rbegin(); it != node_topology_list.rend(); it++) {
if (visited_node_index.count(*it) == 0) {
graph.RemoveNode(*it);
}
}
@ -1131,18 +1287,60 @@ Status ApplyPipelinePartitionToMainGraph(
std::vector<Node *> send_nodes, recv_nodes;
send_nodes.reserve(split_count);
recv_nodes.reserve(split_count);
// Split the graph by cutting edges specified in cut_info. After this function, the graph will be
// composed of several disconnected partitions.
ORT_RETURN_IF_ERROR(SplitGraph(graph, cut_info, send_nodes, recv_nodes));
if (send_nodes.size() != split_count || recv_nodes.size() != split_count) {
ORT_THROW("Split error: not all cut has Send and Recv inserted. Send node count: ",
send_nodes.size(), ", Recv node count: ", recv_nodes.size(), ", split count: ", split_count);
send_nodes.size(), ", Recv node count: ", recv_nodes.size(), ", split count: ", split_count);
}
// Check to see if there are any initializers that is being shared between different partitions. If there
// is, keep the initializer in the first seen partition and have it pass through by send/recv to the others.
ORT_RETURN_IF_ERROR(HandleSharedInitializer(graph, send_nodes, recv_nodes));
// Now remove the partitions that are not tie to the current pipeline stage and generate the sub-graph.
if (pipeline_stage_id < split_count) {
ORT_RETURN_IF_ERROR(GenerateSubgraph(graph, send_nodes[pipeline_stage_id]));
} else {
ORT_RETURN_IF_ERROR(GenerateSubgraph(graph, recv_nodes.back()));
}
// Post check to ensure the curent partition is correct and matches with Send/Recv nodes inserted during split.
Node* send_node{nullptr};
Node* recv_node{nullptr};
for (auto& node : graph.Nodes()) {
if (node.OpType() == "Send") {
send_node = &node;
} else if (node.OpType() == "Recv") {
recv_node = &node;
}
}
if (pipeline_stage_id == 0){
// For the first stage, there should be no recv node, and the send node contained in graph should match the first
// send_node inserted during split.
ORT_ENFORCE(recv_node == nullptr, "Error: first stage contains Recv node in forward pass.");
ORT_ENFORCE(send_node == send_nodes[0],
"Error: first stage doesn't contain the right Send node. Possibly CutInfo data is wrong.");
}
else if (pipeline_stage_id == split_count){
// For the last stage, there should be no send node, and the recv node contained in graph should match the last
// recv_node inserted during split.
ORT_ENFORCE(recv_node == recv_nodes.back(),
"Error: last stage doesn't contain the right Recv node. Possibly CutInfo data is wrong.");
ORT_ENFORCE(send_node == nullptr, "Error: last stage contains Send node in forward pass.");
} else {
// For stages in the middle, i-th stage should contain recv node that matches the (i-1)-th inserted recv node, and the i-th
// inserted send node.
ORT_ENFORCE(recv_node == recv_nodes[pipeline_stage_id - 1],
"Error: stage ", pipeline_stage_id, " doesn't contain the right Recv node. Possibly CutInfo data is wrong.");
ORT_ENFORCE(send_node == send_nodes[pipeline_stage_id],
"Error: stage ", pipeline_stage_id, " doesn't contain the right Send node. Possibly CutInfo data is wrong.");
}
return Status::OK();
}

View file

@ -1117,7 +1117,114 @@ void RetrieveSendRecvOperators(
}
}
TEST(GradientGraphBuilderTest, PipelineOnlinePartition) {
PathString GenerateFileNameWithIndex(const PathString& base_str, int index, const PathString& file_suffix) {
Path p;
ORT_ENFORCE(Path::Parse(base_str, p).IsOK());
return p.ConcatIndex(index).Concat(file_suffix).ToPathString();
}
TEST(GradientGraphBuilderTest, PipelineOnlinePartition_bert_tiny) {
const auto model_path = ORT_TSTR("testdata/bert_toy_optimized.onnx");
const size_t total_partition_count = 3;
TrainingSession::TrainingConfiguration::PipelineConfiguration pipe{};
pipe.do_partition = true;
// cut model in 3 partitions
TrainingSession::TrainingConfiguration::CutInfo cut0 = {
onnxruntime::training::TrainingSession::TrainingConfiguration::CutEdge("326"),
onnxruntime::training::TrainingSession::TrainingConfiguration::CutEdge("103", {"413", "529"})};
TrainingSession::TrainingConfiguration::CutInfo cut1 = {
onnxruntime::training::TrainingSession::TrainingConfiguration::CutEdge("558"),
onnxruntime::training::TrainingSession::TrainingConfiguration::CutEdge("103", {"645"})};
pipe.cut_list.emplace_back(cut0);
pipe.cut_list.emplace_back(cut1);
TrainingSession::TrainingConfiguration::MixedPrecisionConfiguration mixed_precision_config{};
mixed_precision_config.use_fp16_initializers = true;
// 2 test variations - full precision and mixed precision
const std::vector<bool> test_with_fp32{true, false};
for (auto is_fp32 : test_with_fp32) {
// graph is partitioned into 3 parts.
for (int i = 0; i < static_cast<int>(total_partition_count); ++i) {
PathString output_file = GenerateFileNameWithIndex(ORT_TSTR("pipeline_partition_"), i, ORT_TSTR("_back.onnx"));
auto config = MakeBasicTrainingConfig();
if (i == static_cast<int>(total_partition_count - 1)) {
config.loss_function_config = TrainingSession::TrainingConfiguration::LossFunctionConfiguration{};
config.loss_function_config.value().loss_function_info =
LossFunctionInfo(OpDef("BertLoss", kOnnxDomain),
"total_loss",
{/*prediction_masked_lm*/ "prediction_scores",
/*prediction_next_sentence*/ "seq_relationship_score",
/*masked_lm_positions*/ "masked_lm_positions",
/*masked_lm_ids*/ "masked_lm_ids",
/*masked_lm_weights*/ "masked_lm_weights",
/*next_sentence_labels*/ "next_sentence_labels",
/*mlm_loss*/ "mlm_loss",
/*nsp_loss*/ "nsp_loss"});
}
// Add weight_names_to_not_train to avoid generating backward graph on those tensor
config.weight_names_to_not_train = {
"position_01", // Slice's dat input
"op_min_ends_expand_10", //op_min_ends_expand_10
};
config.pipeline_config = pipe;
config.distributed_config.world_rank = i;
config.distributed_config.world_size = total_partition_count;
config.distributed_config.local_rank = i;
config.distributed_config.local_size = total_partition_count;
config.distributed_config.data_parallel_size = 1;
config.distributed_config.horizontal_parallel_size = 1;
config.distributed_config.pipeline_parallel_size = total_partition_count;
config.model_with_training_graph_path = output_file;
if (!is_fp32) {
config.mixed_precision_config = mixed_precision_config;
}
PathString backprop_model_file;
Status status = BuildBackPropGraph(model_path, config, backprop_model_file);
ASSERT_TRUE(status.IsOK()) << status << " (is_fp32 = " << is_fp32 << ", stage = " << i << ").\n";
// Skip the re-load for mixed-precision case. This model contains grad op that has function body,
// which takes a const tensor input. Const cast for input in function body won't be saved in the output
// model so reload will run into error.
// For the purpose of testing mixed-precision, BuildBackPropGraph above will be sufficient to verify the
// partition logic and validate the graph.
if (is_fp32) {
std::shared_ptr<Model> model;
// Ensure the partitioned model load.
status = Model::Load(backprop_model_file, model, nullptr, DefaultLoggingManager().DefaultLogger());
ASSERT_TRUE(status.IsOK()) << status << " (is_fp32 = " << is_fp32 << ", stage = " << i << ").\n";
// verify the first stage contains word embedding as input and the last stage doesn't
auto model_proto = model->ToProto();
const auto& graph_proto = model_proto.graph();
bool found_word_embedding = false;
for (auto& tensor : graph_proto.initializer()) {
if (tensor.name() == "bert.embeddings.word_embeddings.weight") {
found_word_embedding = true;
}
}
if (i == 0) {
ASSERT_TRUE(found_word_embedding) << " (is_fp32 = " << is_fp32 << ", stage = " << i << ").\n";
} else {
ASSERT_FALSE(found_word_embedding) << " (is_fp32 = " << is_fp32 << ", stage = " << i << ").\n";
}
}
}
}
}
TEST(GradientGraphBuilderTest, PipelineOnlinePartition_MLP) {
auto model_uri = ORIGINAL_MODEL_PATH;
TrainingSession::TrainingConfiguration::PipelineConfiguration pipe{};
@ -1137,14 +1244,10 @@ TEST(GradientGraphBuilderTest, PipelineOnlinePartition) {
for(auto is_fp32 : test_with_fp32) {
// graph is partitioned into 3 parts.
for (int i = 0; i < 3; ++i) {
#ifdef _WIN32
auto surfix = std::to_wstring(i);
#else
auto surfix = std::to_string(i);
#endif
PathString output_file = ORT_TSTR("pipeline_partition_") + surfix + ORT_TSTR("_back.onnx");
PathString output_file = GenerateFileNameWithIndex(ORT_TSTR("pipeline_partition_"), i, ORT_TSTR("_back.onnx"));
auto config = MakeBasicTrainingConfig();
config.pipeline_config = pipe;
config.distributed_config.world_rank = i;
config.distributed_config.world_size = 3;
@ -1178,9 +1281,62 @@ TEST(GradientGraphBuilderTest, PipelineOnlinePartition) {
}
}
Status RunOnlinePartition(const std::vector<TrainingSession::TrainingConfiguration::CutInfo>& cut_list,
int pipeline_stage_size,
std::set<int> status_check_stages = {}) {
auto model_uri = ORIGINAL_MODEL_PATH;
TrainingSession::TrainingConfiguration::PipelineConfiguration pipe{};
pipe.do_partition = true;
pipe.cut_list = cut_list;
for (int i = 0; i < pipeline_stage_size; ++i) {
PathString output_file = GenerateFileNameWithIndex(ORT_TSTR("pipeline_partition_"), i, ORT_TSTR("_back.onnx"));
auto config = MakeBasicTrainingConfig();
config.pipeline_config = pipe;
config.distributed_config.world_rank = i;
config.distributed_config.world_size = pipeline_stage_size;
config.distributed_config.local_rank = i;
config.distributed_config.local_size = pipeline_stage_size;
config.distributed_config.data_parallel_size = 1;
config.distributed_config.horizontal_parallel_size = 1;
config.distributed_config.pipeline_parallel_size = pipeline_stage_size;
config.model_with_training_graph_path = output_file;
PathString backprop_model_file;
if (status_check_stages.count(i) > 0) {
auto status = BuildBackPropGraph(model_uri, config, backprop_model_file);
EXPECT_FALSE(status.IsOK());
} else {
EXPECT_THROW(BuildBackPropGraph(model_uri, config, backprop_model_file), OnnxRuntimeException);
}
}
return Status::OK();
}
TEST(GradientGraphBuilderTest, PipelineOnlinePartition_Invalid_Input) {
using CutEdge = TrainingSession::TrainingConfiguration::CutEdge;
using CutInfo = TrainingSession::TrainingConfiguration::CutInfo;
// Test with invalid cut edge
TrainingSession::TrainingConfiguration::CutInfo invalid_cut_edge = {TrainingSession::TrainingConfiguration::CutEdge("3")};
ASSERT_STATUS_OK(RunOnlinePartition(std::vector<TrainingSession::TrainingConfiguration::CutInfo>{invalid_cut_edge}, 2 /* pipeline_stage_size */));
// Test mis-matched cut list with stage size
TrainingSession::TrainingConfiguration::CutInfo cut_edge = {TrainingSession::TrainingConfiguration::CutEdge("T3")};
ASSERT_STATUS_OK(RunOnlinePartition(std::vector<TrainingSession::TrainingConfiguration::CutInfo>{cut_edge}, 3 /* pipeline_stage_size */));
// Test unordered cut_info list
CutInfo cut0 = {CutEdge("T3")};
CutInfo cut1 = {CutEdge("T6")};
ASSERT_STATUS_OK(RunOnlinePartition(std::vector<CutInfo>{cut1, cut0}, 3 /* pipeline_stage_size */, {0, 2} /* status_check_stages */));
}
// verify pipeline config can load and gradient graph can construct.
TEST(GradientGraphBuilderTest, TrainingSession_PipelineTransform_base) {
PathString filename_base = ORT_TSTR("testdata/test_training_model_");
PathString filename_base = ORT_TSTR("testdata/test_training_model_");
auto load_and_check_gradient_graph = [](int stageIdx, PathString& input_file, PathString& output_file) {
auto config = MakeBasicTrainingConfig();
@ -1286,13 +1442,9 @@ TEST(GradientGraphBuilderTest, TrainingSession_PipelineTransform_base) {
};
for (int i = 0; i < 3; ++i) {
#ifdef _WIN32
auto surfix = std::to_wstring(i);
#else
auto surfix = std::to_string(i);
#endif
PathString input_file = filename_base + surfix + ORT_TSTR(".onnx");
PathString output_file = filename_base + surfix + ORT_TSTR("_back.onnx");
PathString input_file = GenerateFileNameWithIndex(filename_base, i, ORT_TSTR(".onnx"));
PathString output_file = GenerateFileNameWithIndex(filename_base, i, ORT_TSTR("_back.onnx"));
load_and_check_gradient_graph(i, input_file, output_file);
}
}
@ -1358,12 +1510,7 @@ TEST(GradientGraphBuilderTest, TrainingSession_WithPipeline) {
std::vector<PathString> sub_model_files(num_subs);
for (size_t sub_id = 0; sub_id < num_subs; ++sub_id) {
#ifdef _WIN32
auto sub_id_str = std::to_wstring(sub_id);
#else
auto sub_id_str = std::to_string(sub_id);
#endif
sub_model_files[sub_id] = ORT_TSTR("sub_") + sub_id_str + ORT_TSTR(".onnx");
sub_model_files[sub_id] = GenerateFileNameWithIndex(ORT_TSTR("sub_"), sub_id, ORT_TSTR(".onnx"));
}
PipelineSplitter splitter;
@ -1383,12 +1530,7 @@ TEST(GradientGraphBuilderTest, TrainingSession_WithPipeline) {
for (size_t sub_id = 0; sub_id < num_subs; ++sub_id) {
auto& sub_sess = subs[sub_id];
sub_sess.so.enable_profiling = true;
#ifdef _WIN32
auto sub_id_str = std::to_wstring(sub_id);
#else
auto sub_id_str = std::to_string(sub_id);
#endif
sub_sess.so.profile_file_prefix = ORT_TSTR("pipeline") + sub_id_str;
sub_sess.so.profile_file_prefix = GenerateFileNameWithIndex(ORT_TSTR("pipeline"), static_cast<int>(sub_id), ORT_TSTR(""));
sub_sess.run_options.run_log_verbosity_level = sub_sess.so.session_log_verbosity_level;
sub_sess.run_options.run_tag = sub_sess.so.session_logid;