diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 10494836c4..b7f2747d5e 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -140,6 +140,25 @@ class Node { return common::Status::OK(); } + /** + Helper to iterate through the container returned by #MutableInputDefs() or #MutableOutputDefs() and call the provided function. + @param node_args Collection of NodeArgs returned by #MutableInputDefs() or #MutableOutputDefs() + @param func Function to call for each valid NodeArg in the node_args. The function is called with the NodeArg + and the index number in the container. + @returns common::Status with success or error information. + @remarks Returns immediately on error. + */ + static common::Status ForEachMutableWithIndex(std::vector& node_args, + std::function func) { + for (size_t index = 0; index < node_args.size(); ++index) { + auto arg = node_args[index]; + if (!arg->Exists()) + continue; + ORT_RETURN_IF_ERROR(func(*arg, index)); + } + return common::Status::OK(); + } + /** Gets the count of arguments for each of the Node's explicit inputs. */ const std::vector& InputArgCount() const noexcept { return definitions_.input_arg_count; } @@ -507,6 +526,9 @@ class Graph { /** Remove the initializer tensor with the provided name from the Graph. */ void RemoveInitializedTensor(const std::string& tensor_name); + /** Check if a given name is an initializer tensor's name in this graph. */ + bool IsInitializedTensor(const std::string& name) const; + /** Replaces the initializer tensor with the same name as the given initializer tensor. The replacement initializer tensor must have the same type and shape as the existing initializer tensor. @@ -635,15 +657,14 @@ class Graph { if (iter != node_args_.end()) { return *(iter->second); } - auto result = node_args_.insert(std::make_pair(name, onnxruntime::make_unique(name, p_arg_type))); return *(result.first->second); } - /** Generate a unique name.in this Graph for a NodeArg */ + /** Generate a unique name in this Graph for a NodeArg */ std::string GenerateNodeArgName(const std::string& base_name); - /** Generate a unique name.in this Graph for a Node */ + /** Generate a unique name in this Graph for a Node */ std::string GenerateNodeName(const std::string& base_name); /** Add a Node to this Graph. diff --git a/onnxruntime/core/common/path.cc b/onnxruntime/core/common/path.cc index ed06da52f7..0d60e884d1 100644 --- a/onnxruntime/core/common/path.cc +++ b/onnxruntime/core/common/path.cc @@ -253,6 +253,21 @@ Path& Path::Append(const Path& other) { return *this; } +Path& Path::Concat(const PathString& string) { + components_.back() += string; + return *this; +} + +Path& Path::ConcatIndex(const int index) { +#ifdef _WIN32 + auto index_str = std::to_wstring(index); +#else + auto index_str = std::to_string(index); +#endif + components_.back() += index_str; + return *this; +} + Status RelativePath(const Path& src, const Path& dst, Path& rel) { ORT_RETURN_IF_NOT( src.GetRootPathString() == dst.GetRootPathString(), diff --git a/onnxruntime/core/common/path.h b/onnxruntime/core/common/path.h index cf9fc3910f..9dac03b725 100644 --- a/onnxruntime/core/common/path.h +++ b/onnxruntime/core/common/path.h @@ -61,6 +61,19 @@ class Path { * The algorithm should model that of std::filesystem::path::append(). */ Path& Append(const Path& other); + + /** + * Concatenates the current path and the argument string. + * Unlike with Append() or operator/=, additional directory separators are never introduced. + */ + Path& Concat(const PathString& string); + + /** + * Concatenates an index by the end of current path. + * Similar to Concat() except the argument is an index. + */ + Path& ConcatIndex(const int index); + /** Equivalent to this->Append(other). */ Path& operator/=(const Path& other) { return Append(other); diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 08e001eb8a..3a7cbc68c9 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -2331,6 +2331,10 @@ static void RemoveRepeatedFieldEntry(T& repeated_field, const TIter& entry_to_re } } +bool Graph::IsInitializedTensor(const std::string& name) const { + return name_to_initial_tensor_.count(name) > 0; +} + void Graph::RemoveInitializedTensor(const std::string& tensor_name) { bool found = false; auto iter = name_to_initial_tensor_.find(tensor_name); diff --git a/onnxruntime/core/graph/graph_utils.cc b/onnxruntime/core/graph/graph_utils.cc index c2a6a7b0dc..4065fa67a5 100644 --- a/onnxruntime/core/graph/graph_utils.cc +++ b/onnxruntime/core/graph/graph_utils.cc @@ -181,7 +181,7 @@ static void RemoveGraphEdges(Graph& graph, const std::vector& edges) } /** Given a graph, a list of edges, and a NodeArg name, checks if each of the edges provides an implicit input - to a subgraph. If so, it checks if there is no clash of the given NodeArg name in each of the subgraphs. + to a subgraph. If so, it checks if there is no clash of the given NodeArg name in each of the subgraphs. This is important when removing a node with this NodeArg as input. */ static bool CanUpdateImplicitInputNameInSubgraphs(const Graph& graph, const std::vector& output_edges, @@ -321,7 +321,7 @@ const ONNX_NAMESPACE::AttributeProto* GetNodeAttribute(const Node& node, const s return iter == attrs.end() ? nullptr : &iter->second; } -/** Checks for nodes with >= 1 outputs, if only one of the outputs is input to downstream Operators. +/** Checks for nodes with >= 1 outputs, if only one of the outputs is input to downstream Operators. Returns the name of the single used output in output_name. */ static bool IsOnlyOneOutputUsed(const Graph& graph, const Node& node, const std::string*& output_name) { const int unassigned = -1; @@ -807,5 +807,10 @@ bool RemoveNodesWithOneOutputBottomUp(Graph& graph, const Node& start_node) { } return true; } + +NodeArg& CreateNodeArg(Graph& graph, const NodeArg& base_arg) { + return graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(base_arg.Name()), base_arg.TypeAsProto()); +} + } // namespace graph_utils } // namespace onnxruntime diff --git a/onnxruntime/core/graph/graph_utils.h b/onnxruntime/core/graph/graph_utils.h index 546619f9f1..b23da03c7f 100644 --- a/onnxruntime/core/graph/graph_utils.h +++ b/onnxruntime/core/graph/graph_utils.h @@ -34,27 +34,27 @@ bool IsOutputUsed(const Node& node, int index); /** Returns true if the graph has the given input.*/ bool IsGraphInput(const Graph& graph, const NodeArg* input); -/** returns true if 'name' is an initializer in 'graph', or an ancestor graph if check_outer_scope is true. +/** returns true if 'name' is an initializer in 'graph', or an ancestor graph if check_outer_scope is true. @param check_outer_scope If true and 'graph' is a subgraph, check ancestor graph/s for 'name' if not found in 'graph'. */ bool IsInitializer(const Graph& graph, const std::string& name, bool check_outer_scope); -/** returns true if 'name' is an initializer, and is constant and cannot be overridden at runtime. +/** returns true if 'name' is an initializer, and is constant and cannot be overridden at runtime. @param check_outer_scope If true and 'graph' is a subgraph, check ancestor graph/s for 'name' if not found in 'graph'. */ bool IsConstantInitializer(const Graph& graph, const std::string& name, bool check_outer_scope = true); -/** returns the initializer's TensorProto if 'name' is an initializer, is constant and +/** returns the initializer's TensorProto if 'name' is an initializer, is constant and cannot be overridden at runtime. If the initializer is not found or is not constant, a nullptr is returned. @param check_outer_scope If true and the graph is a subgraph, check ancestor graph/s for 'name' if not found in 'graph'. */ const ONNX_NAMESPACE::TensorProto* GetConstantInitializer(const Graph& graph, const std::string& name, bool check_outer_scope = true); -/** Add a new initializer to 'graph'. +/** Add a new initializer to 'graph'. Checks that new_initializer does not already exist in 'graph' before adding it. -@returns The NodeArg for the new initializer. -@remarks No matching graph input is created, so the initializer will be constant. +@returns The NodeArg for the new initializer. +@remarks No matching graph input is created, so the initializer will be constant. */ NodeArg& AddInitializer(Graph& graph, const ONNX_NAMESPACE::TensorProto& new_initializer); @@ -127,7 +127,7 @@ bool RemoveNode(Graph& graph, Node& node); /** Tests if we can remove a node and replace its output with an initializer. Conditions: - - Only one of the node's outputs is used by downstream operators or as a graph output + - Only one of the node's outputs is used by downstream operators or as a graph output - multiple edges for the single used output are allowed - If the node produces a graph output the initializer_name must be the same as the node's output name - otherwise the required graph output will not be produced @@ -140,20 +140,20 @@ bool CanReplaceNodeWithInitializer(const Graph& graph, const Node& node, const s See CanReplaceNodeWithInitializer for the conditions that must be satisfied in order to remove the node.*/ bool ReplaceNodeWithInitializer(Graph& graph, Node& node, NodeArg& replacement); -/** Removes all output edges from the given Node of the Graph. +/** Removes all output edges from the given Node of the Graph. This should probably be elevated to the Graph API eventually. */ size_t RemoveNodeOutputEdges(Graph& graph, Node& node); -/** Replaces the input to nodes that are downstream from 'node', which was being provided by an output of 'node', +/** Replaces the input to nodes that are downstream from 'node', which was being provided by an output of 'node', with an output from a different node. Moves the output edges from 'node' for 'output_idx' to the replacement node. @param replacement The node providing the replacement output. -@param replacement_output_idx The index of the output from 'replacement' to use. +@param replacement_output_idx The index of the output from 'replacement' to use. -e.g. Node A produces outputs A1 and A2. - Node B consumes A2 (edge between A and B for A2) and produces B1. +e.g. Node A produces outputs A1 and A2. + Node B consumes A2 (edge between A and B for A2) and produces B1. Node C consumes B1 (edge between B and C for B1). - - If Node B was determined to not be needed, you would call ReplaceDownstreamNodeInput(graph, B, 0, A, 1) + + If Node B was determined to not be needed, you would call ReplaceDownstreamNodeInput(graph, B, 0, A, 1) to replace B1 (output index 0 for node B) with A2 (output index 1 for node A) as input to the downstream node C. The edge that existed between B and C for B1 will be removed, and replaced with an edge between A and C for A2. */ @@ -161,29 +161,29 @@ void ReplaceDownstreamNodeInput(Graph& graph, Node& node, int output_idx, Node& /** Replace the input to a node with a NodeArg. @remarks The replacement only updates the node's input definition and does not create any edges, - as typically this function is used to replace an input with an initializer or graph input + as typically this function is used to replace an input with an initializer or graph input (there is no edge between an initializer or graph input and a Node). */ void ReplaceNodeInput(Node& target, int target_input_idx, NodeArg& new_input); /** Add an input to a node with a NodeArg for an initializer or graph input. -@remarks target_input_idx must be the next input slot. - e.g. if a Node has 2 inputs, AddNodeInput can only add input 3 and not 4. - There is no edge between an initializer or graph input and a Node, so the replacement only updates the +@remarks target_input_idx must be the next input slot. + e.g. if a Node has 2 inputs, AddNodeInput can only add input 3 and not 4. + There is no edge between an initializer or graph input and a Node, so the replacement only updates the node's input definition and does not create any new edges. */ void AddNodeInput(Node& target, int target_input_idx, NodeArg& new_input); -/** Finalize the fusion of second_node into first_node. +/** Finalize the fusion of second_node into first_node. The output definitions and edges from the second_node are moved to first_node. second_node is deleted. e.g. Conv + Add fusion fuses the 'Add' into the Conv. */ void FinalizeNodeFusion(Graph& graph, Node& first_node, Node& second_node); -/** Finalize the fusion of two or more nodes which are being replaced with a single node. +/** Finalize the fusion of two or more nodes which are being replaced with a single node. The first and last entries in 'nodes' are assumed to be the first and last nodes in a chain of nodes being fused. - Conceptually multiple nodes are being combined into one, and post-fusion will produce output/s with the same names + Conceptually multiple nodes are being combined into one, and post-fusion will produce output/s with the same names as the last node in 'nodes', and be connected to the same downstream nodes. The input edges to the first node in 'nodes' will be moved to replacement_node. No other input edges are moved. @@ -229,7 +229,7 @@ struct EdgeEndToMatch { @param edges_to_match has information of a sequence of adjacent edges in the path to be matched one by one. @param result stores edges that are found. @returns false when one edge has multiple candidates, or not all edges are found. -@remarks matching an EdgeEndToMatch might get multiple candidates in output edges. +@remarks matching an EdgeEndToMatch might get multiple candidates in output edges. When such case is encountered, this function will return false. This is by design to reduce complexity. Here is an example graph: Add @@ -237,7 +237,7 @@ struct EdgeEndToMatch { Mul Mul \ / Sub - For example, you want to match path from top to bottom: Add-->Mul-->Sub. + For example, you want to match path from top to bottom: Add-->Mul-->Sub. When matching the first edge Add-->Mul, the algorithm found two matches. Then it returns false, and output a warning log entry. @@ -252,9 +252,15 @@ bool FindPath(Graph& graph, const Node& node, bool is_input_edge, const std::vec /** * Remove nodes with only one output edge using bottom-up bfs traversal. - * @param node: The node to start with. + * @param node: The node to start with. * @returns true if there is one or more node(s) removed by this function. Otherwise return false. */ bool RemoveNodesWithOneOutputBottomUp(Graph& graph, const Node& node); + +/** Creates a mutable NodeArg owned by the graph with mirrored base_arg's TypeProto and name + * @param base_arg The NodeArg the newly created NodeArg is mirrored based off. + * @returns NodeArg reference that contains the same TypeProto info as base_arg with generated different names. +*/ +NodeArg& CreateNodeArg(Graph& graph, const NodeArg& base_arg); } // namespace graph_utils } // namespace onnxruntime diff --git a/onnxruntime/test/common/path_test.cc b/onnxruntime/test/common/path_test.cc index 8005b2cc83..ea8908affc 100644 --- a/onnxruntime/test/common/path_test.cc +++ b/onnxruntime/test/common/path_test.cc @@ -222,5 +222,34 @@ TEST(PathTest, RelativePathFailure) { #endif } +TEST(PathTest, Concat) { + auto check_concat = + [](const std::string& a, const std::string& b, const std::string& expected_a) { + Path p_a{}, p_expected_a{}; + ASSERT_STATUS_OK(Path::Parse(ToPathString(a), p_a)); + ASSERT_STATUS_OK(Path::Parse(ToPathString(expected_a), p_expected_a)); + + EXPECT_EQ(p_a.Concat(ToPathString(b)).ToPathString(), p_expected_a.ToPathString()); + }; + + check_concat("/a/b", "c", "/a/bc"); + check_concat("a/b", "cd", "a/bcd"); +} + +TEST(PathTest, ConcatIndex) { + auto check_concat_index = + [](const std::string& a, const int i, const std::string& expected_a) { + Path p_a{}, p_expected_a{}; + ASSERT_STATUS_OK(Path::Parse(ToPathString(a), p_a)); + ASSERT_STATUS_OK(Path::Parse(ToPathString(expected_a), p_expected_a)); + + EXPECT_EQ(p_a.ConcatIndex(i).ToPathString(), p_expected_a.ToPathString()); + }; + + check_concat_index("/a/b", 0, "/a/b0"); + check_concat_index("a/b", 123, "a/b123"); + check_concat_index("a/b", -1, "a/b-1"); +} + } // namespace test } // namespace onnxruntime diff --git a/orttraining/orttraining/core/graph/pipeline_transformer.cc b/orttraining/orttraining/core/graph/pipeline_transformer.cc index f67603da7f..b0c8959824 100644 --- a/orttraining/orttraining/core/graph/pipeline_transformer.cc +++ b/orttraining/orttraining/core/graph/pipeline_transformer.cc @@ -4,7 +4,10 @@ #include "orttraining/core/graph/pipeline_transformer.h" #include +#include "core/graph/graph_utils.h" + using namespace onnxruntime::common; +using namespace onnxruntime::graph_utils; namespace onnxruntime { namespace training { @@ -78,17 +81,15 @@ std::vector FindBackwardLeafNodes(Graph& graph) { // The newly created boolean scalar may be appended to signal_args. If // signal_args is empty, the source of signal_args[i] would be tensor_args[i]. void ConvertTensorToBoolSignal( - Graph& graph, - const std::vector& tensor_args, - std::vector& signal_args) { - - for (auto tensor_arg: tensor_args) { + Graph& graph, + const std::vector& tensor_args, + std::vector& signal_args) { + for (auto tensor_arg : tensor_args) { // Declare the scalar signal this "tensor_arg" will be converted into. auto signal_arg = &CreateTypedNodeArg( - graph, - ONNX_NAMESPACE::TensorProto_DataType_BOOL, - "signal_" + tensor_arg->Name() - ); + graph, + ONNX_NAMESPACE::TensorProto_DataType_BOOL, + "signal_" + tensor_arg->Name()); // Add the new scalar to user-specified vector. signal_args.push_back(signal_arg); @@ -98,37 +99,28 @@ void ConvertTensorToBoolSignal( std::vector input_args{tensor_arg}; std::vector output_args{signal_arg}; graph.AddNode( - name, - "Group", - "", - input_args, - output_args, - nullptr, - kMSDomain); + name, + "Group", + "", + input_args, + output_args, + nullptr, + kMSDomain); } } -// Return mirror variables for node_arg with a different name. -NodeArg& CreateNodeArg(Graph& graph, const NodeArg* base_arg) { - const auto& new_name = graph.GenerateNodeArgName(base_arg->Name()); - ONNX_NAMESPACE::TypeProto type_proto(*(base_arg->TypeAsProto())); - if (graph.GetNodeArg(new_name) != nullptr) { - ORT_THROW("Node with name ", new_name, " already exists."); - } - return graph.GetOrCreateNodeArg(new_name, &type_proto); -} - // Return mirror variables for node_args. // The i-th output element mirrors node_args[i] but with a different name. std::vector CreateMirrorNodeArgs( - Graph& graph, - const std::vector& node_args) { + Graph& graph, + const std::vector& node_args) { // Declare output. std::vector new_node_args; - for (auto& node_arg: node_args) { + for (auto& node_arg : node_args) { // new_node_arg is a mirror variable of node_arg. They have the same type. - auto new_node_arg = &CreateNodeArg(graph, node_arg); + assert(node_arg); + auto new_node_arg = &CreateNodeArg(graph, *node_arg); new_node_args.push_back(new_node_arg); } @@ -138,33 +130,33 @@ std::vector CreateMirrorNodeArgs( // Create a node with input schema [event, input1, input2, ..., inputN] and // output schema [input1, input2, ..., inputN] Node& CreateBottleneckNode(Graph& graph, - const std::string& op_type, - const std::string& op_name, - const std::string& description, - NodeArg* event, - std::vector input_node_args, - std::vector output_node_args) { + const std::string& op_type, + const std::string& op_name, + const std::string& description, + NodeArg* event, + std::vector input_node_args, + std::vector output_node_args) { const auto name = graph.GenerateNodeName(op_name); if (event) { input_node_args.insert(input_node_args.begin(), event); } return graph.AddNode( - name, - op_type, - description, - input_node_args, - output_node_args, - nullptr /* assume all bottleneck node have no attributes */, - kMSDomain); + name, + op_type, + description, + input_node_args, + output_node_args, + nullptr /* assume all bottleneck node have no attributes */, + kMSDomain); } Node* AddBackwardRecord(Graph& graph, Node* backward_send, std::vector& new_input_names, std::vector& new_output_names, - std::string &event_id_tensor_name, - std::string &output_tensor_name) { + std::string& event_id_tensor_name, + std::string& output_tensor_name) { std::vector input_args; AddNewNodeArg(graph, "backward_recorded_event_id", ONNX_NAMESPACE::TensorProto_DataType_INT64, input_args, new_input_names); @@ -194,13 +186,16 @@ Node* AddBackwardRecord(Graph& graph, // Optimizer will be added after applying pipeline transformer. To support partial graph evaluation, // the added Record backward op will have its first passthrough input as output. ORT_ENFORCE(input_args.size() >= 2, "RecordEvent backward op at least have two inputs."); - auto& new_output = CreateNodeArg(graph, input_args[1]); // the first input is signal, not passing through + + // RecordEvent doesn't have optional input, so it cannot be nullptr. + assert(input_args[1]); + auto& new_output = CreateNodeArg(graph, *input_args[1]); // the first input is signal, not passing through output_args.push_back(&new_output); new_output_names.push_back(new_output.Name()); Node* record_node = &CreateBottleneckNode( - graph, "RecordEvent", "backward_record", "Backward pass", nullptr, - input_args, output_args); + graph, "RecordEvent", "backward_record", "Backward pass", nullptr, + input_args, output_args); // First input argument is the recorded event ID tensor. event_id_tensor_name = input_args.front()->Name(); @@ -212,21 +207,19 @@ Node* AddBackwardRecord(Graph& graph, } Node* AddForwardWait(Graph& graph, - Node* /* forward_recv */, - std::vector& new_input_names, - std::string& forward_waited_event_name, - std::string& output_tensor_name) { + Node* /* forward_recv */, + std::vector& new_input_names, + std::string& forward_waited_event_name, + std::string& output_tensor_name) { // Append old_input to input_args and return its pass-through value. Note that // input_args and output_args are Wait's inputs and outputs, respectively. auto update_wait_input_output = [&](NodeArg* old_input, std::vector& input_args, std::vector& output_args) -> NodeArg& { + assert(old_input); input_args.push_back(old_input); - const auto& new_name = graph.GenerateNodeArgName(old_input->Name()); - ONNX_NAMESPACE::TypeProto type_proto(*(old_input->TypeAsProto())); - - auto& wait_output = graph.GetOrCreateNodeArg(new_name, &type_proto); + auto& wait_output = CreateNodeArg(graph, *old_input); output_args.push_back(&wait_output); return wait_output; @@ -257,8 +250,8 @@ Node* AddForwardWait(Graph& graph, } Node* wait_node = &CreateBottleneckNode( - graph, "WaitEvent", "backward_record", "", nullptr, - input_args, output_args); + graph, "WaitEvent", "backward_record", "", nullptr, + input_args, output_args); forward_waited_event_name = input_args.front()->Name(); output_tensor_name = output_args.front()->Name(); @@ -276,13 +269,15 @@ Status AddOrSkipForwardRecordBackwardWait(Graph& graph, std::string& backward_waited_event_name, std::string& forward_output_name, std::string& backward_output_name) { - if (!forward_send != !backward_recv){ - ORT_THROW("Graph requires either having both send forward node " - "and recv backword node, or none of them. Currently the graph " - "has send forward: ", forward_send, " and recv backward: ", backward_recv); + if (!forward_send != !backward_recv) { + ORT_THROW( + "Graph requires either having both send forward node " + "and recv backward node, or none of them. Currently the graph " + "has send forward: ", + forward_send, " and recv backward: ", backward_recv); } - if (!forward_send && !backward_recv){ + if (!forward_send && !backward_recv) { // Last partition doesn't have send forwrad and recv backward. No insert // needed. return Status::OK(); @@ -302,14 +297,16 @@ Status AddOrSkipForwardRecordBackwardWait(Graph& graph, // Add send forward op's output as record op's input and output for (auto& output : forward_send->MutableOutputDefs()) { - auto& new_output = CreateNodeArg(graph, output); + // send doesn't have optional output, so the node cannot be nullptr. + assert(output); + auto& new_output = CreateNodeArg(graph, *output); output_args.push_back(&new_output); input_args.push_back(output); } record_node = &CreateBottleneckNode( - graph, "RecordEvent", "forward_record", "", nullptr, - input_args, output_args); + graph, "RecordEvent", "forward_record", "", nullptr, + input_args, output_args); forward_recorded_event_name = record_node->InputDefs()[0]->Name(); forward_output_name = record_node->OutputDefs()[0]->Name(); @@ -327,13 +324,16 @@ Status AddOrSkipForwardRecordBackwardWait(Graph& graph, std::end(record_node->MutableOutputDefs())); auto& input = backward_recv->MutableInputDefs()[0]; - auto& new_output = CreateNodeArg(graph, input); + + // recv node doesn't have optional input, so the node cannot be nullptr. + assert(input); + auto& new_output = CreateNodeArg(graph, *input); output_args.push_back(&new_output); input = &new_output; wait_node = &CreateBottleneckNode( - graph, "WaitEvent", "backward_wait", "Backward pass", nullptr, - input_args, output_args); + graph, "WaitEvent", "backward_wait", "Backward pass", nullptr, + input_args, output_args); backward_waited_event_name = wait_node->InputDefs()[0]->Name(); backward_output_name = wait_node->OutputDefs()[0]->Name(); @@ -352,8 +352,8 @@ void ReplaceNodeArgs(std::vector& nodes, ORT_ENFORCE(node_args[i]->Name() != new_node_args[i]->Name()); ORT_ENFORCE(node_args[i]->Type() == new_node_args[i]->Type()); - for (auto& node: nodes) { - for (auto& node_arg: node->MutableInputDefs()) { + for (auto& node : nodes) { + for (auto& node_arg : node->MutableInputDefs()) { // Only replace when node's input name matches node_args[i]. if (node_arg->Name().compare(node_args[i]->Name()) != 0) { continue; @@ -371,11 +371,11 @@ void ReplaceNodeArgs(std::vector& nodes, } std::string AddEventBeforeNode( - Graph& graph, - Node* node, - const std::string& event_op_type, - const std::string& event_op_name, - const std::string& event_id_name) { + Graph& graph, + Node* node, + const std::string& event_op_type, + const std::string& event_op_name, + const std::string& event_id_name) { if (!node) { // No event operator is be inserted, so we don't have its input event name. return ""; @@ -410,11 +410,11 @@ std::string AddEventBeforeNode( } std::string AddEventAfterNode( - Graph& graph, - Node* node, - const std::string& event_op_type, - const std::string& event_op_name, - const std::string& event_id_name) { + Graph& graph, + Node* node, + const std::string& event_op_type, + const std::string& event_op_name, + const std::string& event_id_name) { if (!node) { // No event operator is be inserted, so we don't have its input event name. return ""; @@ -431,7 +431,7 @@ std::string AddEventAfterNode( for (size_t i = 0; i < node_args.size(); ++i) { // Find consumer of "node"'s i-th output. std::vector consumer_nodes = graph.GetMutableConsumerNodes( - node_args[i]->Name()); + node_args[i]->Name()); // Replace node_args[i] with new_node_args[i] in nodes. ReplaceNodeArgs(consumer_nodes, {node_args[i]}, {new_node_args[i]}); } @@ -457,9 +457,9 @@ Status AddForwardWaitAfterRecv( std::vector& new_input_names, std::string& event_name) { event_name = AddEventAfterNode( - graph, comm_node, - "WaitEvent", "forward_wait_after_recv", - "forward_wait_after_recv_event_id"); + graph, comm_node, + "WaitEvent", "forward_wait_after_recv", + "forward_wait_after_recv_event_id"); if (event_name.empty()) { return Status::OK(); } else { @@ -474,9 +474,9 @@ Status AddForwardRecordBeforeSend( std::vector& new_input_names, std::string& event_name) { event_name = AddEventBeforeNode( - graph, comm_node, - "RecordEvent", "forward_record_before_send", - "forward_record_before_send_event_id"); + graph, comm_node, + "RecordEvent", "forward_record_before_send", + "forward_record_before_send_event_id"); if (event_name.empty()) { return Status::OK(); } else { @@ -491,9 +491,9 @@ Status AddBackwardWaitAfterRecv( std::vector& new_input_names, std::string& event_name) { event_name = AddEventAfterNode( - graph, comm_node, - "WaitEvent", "backward_wait_after_recv", - "backward_wait_after_recv_event_id"); + graph, comm_node, + "WaitEvent", "backward_wait_after_recv", + "backward_wait_after_recv_event_id"); if (event_name.empty()) { return Status::OK(); } else { @@ -508,9 +508,9 @@ Status AddBackwardRecordBeforeSend( std::vector& new_input_names, std::string& event_name) { event_name = AddEventBeforeNode( - graph, comm_node, - "RecordEvent", "backward_record_before_send", - "backward_record_before_send_event_id"); + graph, comm_node, + "RecordEvent", "backward_record_before_send", + "backward_record_before_send_event_id"); if (event_name.empty()) { return Status::OK(); } else { @@ -605,20 +605,20 @@ Status SetInputsOutputsAndResolve(Graph& graph, // Record-2: Tell others that backward pass is done. // Record-3: Tell others that backward result has been passed to another stage. Status TransformGraphForPipeline( - Graph& graph, - const std::unordered_set& weights_to_train, - std::string& forward_waited_event_name, - std::string& forward_recorded_event_name, - std::string& backward_waited_event_name, - std::string& backward_recorded_event_name, - std::string& forward_wait_output_name, - std::string& forward_record_output_name, - std::string& backward_wait_output_name, - std::string& backward_record_output_name, - std::string& forward_waited_event_after_recv_name, - std::string& forward_recorded_event_before_send_name, - std::string& backward_waited_event_after_recv_name, - std::string& backward_recorded_event_before_send_name) { + Graph& graph, + const std::unordered_set& weights_to_train, + std::string& forward_waited_event_name, + std::string& forward_recorded_event_name, + std::string& backward_waited_event_name, + std::string& backward_recorded_event_name, + std::string& forward_wait_output_name, + std::string& forward_record_output_name, + std::string& backward_wait_output_name, + std::string& backward_record_output_name, + std::string& forward_waited_event_after_recv_name, + std::string& forward_recorded_event_before_send_name, + std::string& backward_waited_event_after_recv_name, + std::string& backward_recorded_event_before_send_name) { // Declare nodes according to their topological order. Node* forward_wait{nullptr}; Node* forward_send{nullptr}; @@ -650,27 +650,27 @@ Status TransformGraphForPipeline( std::vector new_output_names; backward_record = AddBackwardRecord( - graph, - backward_send, - new_input_names, - new_output_names, - backward_recorded_event_name, - backward_record_output_name); + graph, + backward_send, + new_input_names, + new_output_names, + backward_recorded_event_name, + backward_record_output_name); forward_wait = AddForwardWait( - graph, - forward_recv, - new_input_names, - forward_waited_event_name, - forward_wait_output_name); + graph, + forward_recv, + new_input_names, + forward_waited_event_name, + forward_wait_output_name); ORT_RETURN_IF_ERROR(AddOrSkipForwardRecordBackwardWait( - graph, - forward_send, - backward_recv, - new_input_names, - forward_recorded_event_name, - backward_waited_event_name, - forward_record_output_name, - backward_wait_output_name)); + graph, + forward_send, + backward_recv, + new_input_names, + forward_recorded_event_name, + backward_waited_event_name, + forward_record_output_name, + backward_wait_output_name)); // Different stages have different patterns of Send & Recv. // For different patterns, we add different WaitEvent and Record. @@ -694,11 +694,11 @@ Status TransformGraphForPipeline( // One and only one of is_first_stage, is_middle_stage, and is_last_stage can be true. const unsigned int stage_flag_sum = is_first_stage + is_middle_stage + is_last_stage; ORT_RETURN_IF_NOT(stage_flag_sum == 1u, - "The processed graph should be classified into a stage, " - "but we see more than one true's in the following statements. ", - "Is first stage? ", is_first_stage, ". ", - "Is middle stage? ", is_middle_stage, ". ", - "Is last stage? ", is_last_stage, "."); + "The processed graph should be classified into a stage, " + "but we see more than one true's in the following statements. ", + "Is first stage? ", is_first_stage, ". ", + "Is middle stage? ", is_middle_stage, ". ", + "Is last stage? ", is_last_stage, "."); // Now, we add Wait's in parentheses shown below. // 1. First stage: @@ -713,19 +713,17 @@ Status TransformGraphForPipeline( if (is_first_stage) { // If first stage, insert after forward WaitEvent. ORT_RETURN_IF_ERROR(AddForwardWaitAfterRecv( - graph, - forward_wait, - new_input_names, - forward_waited_event_after_recv_name - )); + graph, + forward_wait, + new_input_names, + forward_waited_event_after_recv_name)); } else if (is_middle_stage || is_last_stage) { // If middle stage or last stage, insert after forward Recv. ORT_RETURN_IF_ERROR(AddForwardWaitAfterRecv( - graph, - forward_recv, - new_input_names, - forward_waited_event_after_recv_name - )); + graph, + forward_recv, + new_input_names, + forward_waited_event_after_recv_name)); } // Now, we add Record's in parentheses shown below. @@ -740,11 +738,10 @@ Status TransformGraphForPipeline( // ----------------------> BW -> Record -> Send -> Record if (is_first_stage || is_middle_stage) { ORT_RETURN_IF_ERROR(AddForwardRecordBeforeSend( - graph, - forward_send, - new_input_names, - forward_recorded_event_before_send_name - )); + graph, + forward_send, + new_input_names, + forward_recorded_event_before_send_name)); } // Now, we add Wait's in parentheses shown below. @@ -759,11 +756,10 @@ Status TransformGraphForPipeline( // ----------------------> BW -> Record -> Send -> Record if (is_first_stage || is_middle_stage) { ORT_RETURN_IF_ERROR(AddBackwardWaitAfterRecv( - graph, - backward_recv, - new_input_names, - backward_waited_event_after_recv_name - )); + graph, + backward_recv, + new_input_names, + backward_waited_event_after_recv_name)); } // Now, we add Record's in parentheses shown below. @@ -778,18 +774,16 @@ Status TransformGraphForPipeline( // ----------------------> BW -> (Record) -> Send -> Record if (is_first_stage) { ORT_RETURN_IF_ERROR(AddBackwardRecordBeforeSend( - graph, - backward_record, - new_input_names, - backward_recorded_event_before_send_name - )); + graph, + backward_record, + new_input_names, + backward_recorded_event_before_send_name)); } else if (is_middle_stage || is_last_stage) { ORT_RETURN_IF_ERROR(AddBackwardRecordBeforeSend( - graph, - backward_send, - new_input_names, - backward_recorded_event_before_send_name - )); + graph, + backward_send, + new_input_names, + backward_recorded_event_before_send_name)); } ORT_RETURN_IF_ERROR(SetInputsOutputsAndResolve(graph, weights_to_train, new_input_names, new_output_names)); @@ -801,11 +795,11 @@ Status TransformGraphForPipeline( // It also cerates an initializer to store its value. template void AddNewScalarNodeArgAndInitializer(Graph& graph, - const std::string& op_name, - onnx::TensorProto_DataType type, - T data, - std::vector& new_node_args, - std::vector& new_names) { + const std::string& op_name, + onnx::TensorProto_DataType type, + T data, + std::vector& new_node_args, + std::vector& new_names) { AddNewNodeArg(graph, op_name, type, new_node_args, new_names); ONNX_NAMESPACE::TensorProto proto_data; @@ -825,6 +819,218 @@ void AddNewScalarNodeArgAndInitializer(Graph& graph, graph.AddInitializedTensor(proto_data); } +// Given a node, this function finds all its connected nodes (consumer nodes and producer nodes) and +// connected inputs and outputs in the given graph, then adds them to the containers passed in. +Status FindAllConnectedNodes(Graph& graph, + Node* node, + std::vector& connected_nodes, + std::set& connected_inputs, + std::set& connected_outputs + ) { + assert(node); + ORT_THROW_IF_ERROR(node->ForEachMutableWithIndex( + node->MutableInputDefs(), + [&](NodeArg& node_arg, size_t /*index*/) { + if (graph.IsInputsIncludingInitializers(&node_arg) || graph.IsInitializedTensor(node_arg.Name())) { + connected_inputs.insert(&node_arg); + } else { + Node* producer_node = graph.GetMutableProducerNode(node_arg.Name()); + if (producer_node == nullptr) { + // got nullptr as producer node. This could be because the input is a constant op which will be optimized + // away. Print out this information and continue. + // TODO: re-visit the different cases to see if there are other situations aside from constant ops. + LOGS_DEFAULT(WARNING) << "Cannot find producer node for node_arg: " << node_arg.Name() << ". Skipping this node."; + } else { + connected_nodes.push_back(producer_node); + } + } + return Status::OK(); + })); + + ORT_THROW_IF_ERROR(node->ForEachMutableWithIndex( + node->MutableOutputDefs(), + [&](NodeArg& node_arg, size_t /*index*/) { + if (!graph.IsOutput(&node_arg)) { + std::vector consumer_nodes = graph.GetMutableConsumerNodes(node_arg.Name()); + connected_nodes.insert(std::end(connected_nodes), consumer_nodes.begin(), consumer_nodes.end()); + + } else { + connected_outputs.insert(&node_arg); + } + return Status::OK(); + })); + return Status::OK(); +} + +// PipelineStageNodeGroup groups nodes that share the same input initializer and belong to the same stage. +// It is used to distinguish other nodes that share the same input initializer but belong to +// other pipeline partitions after split. +struct PipelineStageNodeGroup { + const size_t stage_id; + + // Vector of nodes that have the same initializer input and belong to the same stage. Noted that + // the consumer nodes of a particular initializer can be more than one, so we need a vector to store those + // nodes. + std::vector nodes; + PipelineStageNodeGroup(const size_t stage, std::vector& node_group) : stage_id(stage), nodes(std::move(node_group)){}; +}; + +// This function passes through the given initializer across stages specified in node_groups[i].stage_id. +// This applies to the case when initializer is used in multiple stages, say stage a and stage b (a& node_groups, + const std::vector& send_nodes, + const std::vector& recv_nodes) { + assert(initializer); + ORT_ENFORCE(node_groups.size() >= 2, "Initializer ", initializer->Name(), " is not shared across stages."); + + const size_t from_stage = node_groups.front().stage_id; + const size_t to_stage = node_groups.back().stage_id; + + ORT_ENFORCE(from_stage < to_stage, "Pass through from_stage (", from_stage, + ") is not less than the to_stage (", to_stage, ")."); + + auto dtype = initializer->TypeAsProto()->tensor_type().elem_type(); + + auto current_node_arg = initializer; + + size_t node_group_index = 1; + for (auto i = from_stage; i < to_stage; ++i) { + // processing send node in cut i + auto& send_attributes = send_nodes[i]->GetMutableAttributes(); + auto& send_element_types = send_attributes["element_types"]; + send_element_types.add_ints(static_cast(dtype)); + send_nodes[i]->MutableInputDefs().push_back(current_node_arg); + send_nodes[i]->MutableInputArgsCount().back()++; + + // Create a new node_arg for the recv, as the new node_arg from recv node should possess a different id + // than the one in send + assert(current_node_arg); + current_node_arg = &CreateNodeArg(graph, *current_node_arg); + + // process recv node in cut i + auto& recv_attributes = recv_nodes[i]->GetMutableAttributes(); + auto& recv_element_types = recv_attributes["element_types"]; + recv_element_types.add_ints(static_cast(dtype)); + recv_nodes[i]->MutableOutputDefs().push_back(current_node_arg); + + // update the consumer node's input if the node's group is not in the first partition + if (i > from_stage && node_groups[node_group_index].stage_id == (i + 1)) { + for (auto node : node_groups[node_group_index].nodes) { + for (auto& input_node : node->MutableInputDefs()) { + if (input_node == initializer) { + input_node = current_node_arg; + break; + } + } + } + node_group_index++; + } + } + + ORT_ENFORCE(node_group_index == node_groups.size(), "Not all nodes are updated with new initializer."); + + return Status::OK(); +} + +// Traverse the graph to find out all connected elements in the graph from start_node. The traverse treats the graph as an +// undirected graph. +void TraverseGraphWithConnectedElement(Graph& graph, + Node* start_node, + std::set& visited_nodes, + std::set& visited_inputs, + std::set& visited_outputs) { + assert(start_node); + visited_nodes.clear(); + visited_inputs.clear(); + visited_outputs.clear(); + + std::queue node_queue; + node_queue.push(start_node); + + while (!node_queue.empty()) { + auto node = node_queue.front(); + node_queue.pop(); + if (visited_nodes.insert(node).second) { + std::vector connected_nodes; + ORT_THROW_IF_ERROR(FindAllConnectedNodes(graph, node, connected_nodes, visited_inputs, visited_outputs)); + + for (auto n : connected_nodes) { + ORT_ENFORCE(n != nullptr, "Found nullptr in searching for connected nodes"); + node_queue.push(n); + } + } + } +} + +// If an initializer is shared across partitions, instead of creating a separate all_reduce op to +// sync with those tensors in selected partitions, we save only one copy of that initializer in +// the very first partition it appears, and pass that data down to all following partitions +// where this initializer is used. +common::Status HandleSharedInitializer(Graph& graph, + const std::vector& send_nodes, + const std::vector& recv_nodes) { + // Map a given initializer to all the partitions that its consumer nodes reside. The size of + // the mapped vector reflects how many partitions this initializer's consumer nodes distribute. + // If its size is greater than 1, it means this initializer is being used in more than one partition and + // we need to proceed those cases. + std::map> input_consumer_stage_map; + + for (size_t stage = 0; stage <= send_nodes.size(); ++stage) { + std::set visited_nodes; + std::set visited_inputs; + std::set visited_outputs; + + // send_nodes[i] is the Send op in i-th stage's forward pass. recv_nodes[i] is the Recv in the (i+1)-th stage's + // forward pass. When not in last stage, traverse start from send node; otherwise start with the recv node as + // send node doesn't exist in last partition's forward pass. + Node* traverse_start_node = stage < send_nodes.size() ? send_nodes[stage] : recv_nodes.back(); + TraverseGraphWithConnectedElement(graph, + traverse_start_node, + visited_nodes, + visited_inputs, + visited_outputs); + + for (const auto input : visited_inputs) { + // If the node is an input instead of an initializer, continue + if (!graph.IsInitializedTensor(input->Name())){ + continue; + } + + // group all consumer nodes that shares the same input initializer in visited_consumer_nodes + std::vector consumer_nodes = graph.GetMutableConsumerNodes(input->Name()); + std::vector visited_consumer_nodes; + for(auto consumer_node : consumer_nodes){ + if (visited_nodes.count(consumer_node) != 0){ + visited_consumer_nodes.push_back(consumer_node); + } + } + + if (input_consumer_stage_map.count(input) == 0) { + input_consumer_stage_map[input] = std::vector{ + PipelineStageNodeGroup(stage, visited_consumer_nodes)}; + } else { + input_consumer_stage_map[input].push_back({stage, visited_consumer_nodes}); + } + } + } + + for (const auto& entry : input_consumer_stage_map) { + // If any initializer is shared, handle the logic of passing it from the first seen stage all + // the way to last seen stage. + if (entry.second.size() > 1) { + ORT_RETURN_IF_ERROR(AddPassthroughInitializer(graph, + entry.first, // initializer node_arg + entry.second, // initializer consumer node groups + send_nodes, + recv_nodes)); + } + } + return Status::OK(); +} + // split the graph into disconnected subgraph based on provided CutInfo common::Status SplitGraph(Graph& graph, std::vector split_edge_groups, @@ -869,30 +1075,30 @@ common::Status SplitGraph(Graph& graph, auto cut_index_str = std::to_string(index); // add input node_arg and initializer for send/recv AddNewScalarNodeArgAndInitializer(graph, - "send_input_signal" + cut_index_str, - ONNX_NAMESPACE::TensorProto_DataType_BOOL, - true, /* initializer data */ - send_input_args, - new_input_names); + "send_input_signal" + cut_index_str, + ONNX_NAMESPACE::TensorProto_DataType_BOOL, + true, /* initializer data */ + send_input_args, + new_input_names); AddNewScalarNodeArgAndInitializer(graph, - "recv_input_signal" + cut_index_str, - ONNX_NAMESPACE::TensorProto_DataType_BOOL, - true, /* initializer data */ - recv_input_args, - new_input_names); + "recv_input_signal" + cut_index_str, + ONNX_NAMESPACE::TensorProto_DataType_BOOL, + true, /* initializer data */ + recv_input_args, + new_input_names); AddNewScalarNodeArgAndInitializer(graph, - "send_dst_rank" + cut_index_str, - ONNX_NAMESPACE::TensorProto_DataType_INT64, - index + 1, /* initializer data */ - send_input_args, - new_input_names); + "send_dst_rank" + cut_index_str, + ONNX_NAMESPACE::TensorProto_DataType_INT64, + index + 1, /* initializer data */ + send_input_args, + new_input_names); AddNewScalarNodeArgAndInitializer(graph, - "recv_src_rank" + cut_index_str, - ONNX_NAMESPACE::TensorProto_DataType_INT64, - index, /* initializer data */ - recv_input_args, - new_input_names); + "recv_src_rank" + cut_index_str, + ONNX_NAMESPACE::TensorProto_DataType_INT64, + index, /* initializer data */ + recv_input_args, + new_input_names); // add output node_arg for send/recv AddNewNodeArg(graph, "send_output_signal" + cut_index_str, ONNX_NAMESPACE::TensorProto_DataType_BOOL, send_output_args, new_output_names); @@ -948,6 +1154,7 @@ common::Status SplitGraph(Graph& graph, if (exiting_updated_node_arg != updated_node_args.end()) { updated_node_arg = exiting_updated_node_arg->second; } + assert(updated_node_arg); send_input_args.push_back(updated_node_arg); @@ -955,7 +1162,7 @@ common::Status SplitGraph(Graph& graph, element_types.add_ints(static_cast(dtype)); - auto& new_receive_output = CreateNodeArg(graph, updated_node_arg); + auto& new_receive_output = CreateNodeArg(graph, *updated_node_arg); const auto old_shape = *(updated_node_arg->Shape()); new_receive_output.SetShape(old_shape); recv_output_args.push_back(&new_receive_output); @@ -973,7 +1180,7 @@ common::Status SplitGraph(Graph& graph, // deal with updating the consumer's input node_args std::vector consumer_nodes; if (id.consumer_nodes.has_value()) { - for(auto& consumer_node_id : id.consumer_nodes.value()){ + for (auto& consumer_node_id : id.consumer_nodes.value()) { consumer_nodes.push_back(graph.GetMutableProducerNode(consumer_node_id)); } } else { @@ -1019,68 +1226,17 @@ common::Status SplitGraph(Graph& graph, return Status::OK(); } -Status FindAllConnectedNodes(Graph& graph, - const Node* node, - std::vector& connected_nodes, - std::set& connected_inputs, - std::set& connected_outputs) { - ORT_THROW_IF_ERROR(node->ForEachWithIndex( - node->InputDefs(), - [&](const NodeArg& node_arg, size_t /*index*/) { - if (graph.IsInputsIncludingInitializers(&node_arg)) { - connected_inputs.insert(&node_arg); - } else { - const Node* producer_node = graph.GetProducerNode(node_arg.Name()); - if (producer_node == nullptr) { - // got nullptr as producer node. This could be because the input is a constant op which will be optimized - // away. Print out this information and continue. - LOGS_DEFAULT(WARNING) << "Cannot find producer node for node_arg: " << node_arg.Name() << ". Skipping this node."; - } else { - connected_nodes.push_back(producer_node); - } - } - return Status::OK(); - })); - - ORT_THROW_IF_ERROR(node->ForEachWithIndex( - node->OutputDefs(), - [&](const NodeArg& node_arg, size_t /*index*/) { - if (!graph.IsOutput(&node_arg)) { - std::vector consumer_nodes = graph.GetConsumerNodes(node_arg.Name()); - connected_nodes.insert(std::end(connected_nodes), consumer_nodes.begin(), consumer_nodes.end()); - - } else { - connected_outputs.insert(&node_arg); - } - return Status::OK(); - })); - return Status::OK(); -} - // traverse the graph from start_node to get the set of nodes contains in this disconnected subgraph -common::Status GenerateSubgraph(Graph& graph, const Node* start_node) { - std::queue node_queue; - node_queue.push(start_node); - - std::set visited_nodes; - std::set visited_inputs; - std::set visited_outputs; +common::Status GenerateSubgraph(Graph& graph, Node* start_node) { + assert(start_node); + std::set visited_nodes; + std::set visited_inputs; + std::set visited_outputs; // BFS graph traverse - while (!node_queue.empty()) { - auto node = node_queue.front(); - node_queue.pop(); - if (visited_nodes.count(node) == 0) { - visited_nodes.insert(node); - std::vector connected_nodes; - ORT_THROW_IF_ERROR(FindAllConnectedNodes(graph, node, connected_nodes, visited_inputs, visited_outputs)); + TraverseGraphWithConnectedElement(graph, start_node, + visited_nodes, visited_inputs, visited_outputs); - for (auto n : connected_nodes) { - ORT_ENFORCE(n!=nullptr, "Found nullptr in searching for connected nodes"); - node_queue.push(n); - } - } - } std::set visited_node_index; for (auto n : visited_nodes) { visited_node_index.insert(n->Index()); @@ -1090,8 +1246,8 @@ common::Status GenerateSubgraph(Graph& graph, const Node* start_node) { const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); // reverse iterate the nodes in tolopogical order, and delete those not visited - for (auto it = node_topology_list.rbegin(); it != node_topology_list.rend(); it++){ - if (visited_node_index.count(*it)==0){ + for (auto it = node_topology_list.rbegin(); it != node_topology_list.rend(); it++) { + if (visited_node_index.count(*it) == 0) { graph.RemoveNode(*it); } } @@ -1131,18 +1287,60 @@ Status ApplyPipelinePartitionToMainGraph( std::vector send_nodes, recv_nodes; send_nodes.reserve(split_count); recv_nodes.reserve(split_count); + + // Split the graph by cutting edges specified in cut_info. After this function, the graph will be + // composed of several disconnected partitions. ORT_RETURN_IF_ERROR(SplitGraph(graph, cut_info, send_nodes, recv_nodes)); if (send_nodes.size() != split_count || recv_nodes.size() != split_count) { ORT_THROW("Split error: not all cut has Send and Recv inserted. Send node count: ", - send_nodes.size(), ", Recv node count: ", recv_nodes.size(), ", split count: ", split_count); + send_nodes.size(), ", Recv node count: ", recv_nodes.size(), ", split count: ", split_count); } + // Check to see if there are any initializers that is being shared between different partitions. If there + // is, keep the initializer in the first seen partition and have it pass through by send/recv to the others. + ORT_RETURN_IF_ERROR(HandleSharedInitializer(graph, send_nodes, recv_nodes)); + + // Now remove the partitions that are not tie to the current pipeline stage and generate the sub-graph. if (pipeline_stage_id < split_count) { ORT_RETURN_IF_ERROR(GenerateSubgraph(graph, send_nodes[pipeline_stage_id])); } else { ORT_RETURN_IF_ERROR(GenerateSubgraph(graph, recv_nodes.back())); } + + // Post check to ensure the curent partition is correct and matches with Send/Recv nodes inserted during split. + Node* send_node{nullptr}; + Node* recv_node{nullptr}; + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Send") { + send_node = &node; + } else if (node.OpType() == "Recv") { + recv_node = &node; + } + } + + if (pipeline_stage_id == 0){ + // For the first stage, there should be no recv node, and the send node contained in graph should match the first + // send_node inserted during split. + ORT_ENFORCE(recv_node == nullptr, "Error: first stage contains Recv node in forward pass."); + ORT_ENFORCE(send_node == send_nodes[0], + "Error: first stage doesn't contain the right Send node. Possibly CutInfo data is wrong."); + } + else if (pipeline_stage_id == split_count){ + // For the last stage, there should be no send node, and the recv node contained in graph should match the last + // recv_node inserted during split. + ORT_ENFORCE(recv_node == recv_nodes.back(), + "Error: last stage doesn't contain the right Recv node. Possibly CutInfo data is wrong."); + ORT_ENFORCE(send_node == nullptr, "Error: last stage contains Send node in forward pass."); + } else { + // For stages in the middle, i-th stage should contain recv node that matches the (i-1)-th inserted recv node, and the i-th + // inserted send node. + ORT_ENFORCE(recv_node == recv_nodes[pipeline_stage_id - 1], + "Error: stage ", pipeline_stage_id, " doesn't contain the right Recv node. Possibly CutInfo data is wrong."); + ORT_ENFORCE(send_node == send_nodes[pipeline_stage_id], + "Error: stage ", pipeline_stage_id, " doesn't contain the right Send node. Possibly CutInfo data is wrong."); + } + return Status::OK(); } diff --git a/orttraining/orttraining/test/graph/gradient_graph_builder_test.cc b/orttraining/orttraining/test/graph/gradient_graph_builder_test.cc index 6dbb2349ef..e5ff643cec 100644 --- a/orttraining/orttraining/test/graph/gradient_graph_builder_test.cc +++ b/orttraining/orttraining/test/graph/gradient_graph_builder_test.cc @@ -1117,7 +1117,114 @@ void RetrieveSendRecvOperators( } } -TEST(GradientGraphBuilderTest, PipelineOnlinePartition) { +PathString GenerateFileNameWithIndex(const PathString& base_str, int index, const PathString& file_suffix) { + Path p; + ORT_ENFORCE(Path::Parse(base_str, p).IsOK()); + return p.ConcatIndex(index).Concat(file_suffix).ToPathString(); +} + +TEST(GradientGraphBuilderTest, PipelineOnlinePartition_bert_tiny) { + const auto model_path = ORT_TSTR("testdata/bert_toy_optimized.onnx"); + + const size_t total_partition_count = 3; + TrainingSession::TrainingConfiguration::PipelineConfiguration pipe{}; + pipe.do_partition = true; + + // cut model in 3 partitions + TrainingSession::TrainingConfiguration::CutInfo cut0 = { + onnxruntime::training::TrainingSession::TrainingConfiguration::CutEdge("326"), + onnxruntime::training::TrainingSession::TrainingConfiguration::CutEdge("103", {"413", "529"})}; + + TrainingSession::TrainingConfiguration::CutInfo cut1 = { + onnxruntime::training::TrainingSession::TrainingConfiguration::CutEdge("558"), + onnxruntime::training::TrainingSession::TrainingConfiguration::CutEdge("103", {"645"})}; + + pipe.cut_list.emplace_back(cut0); + pipe.cut_list.emplace_back(cut1); + + TrainingSession::TrainingConfiguration::MixedPrecisionConfiguration mixed_precision_config{}; + mixed_precision_config.use_fp16_initializers = true; + + // 2 test variations - full precision and mixed precision + const std::vector test_with_fp32{true, false}; + for (auto is_fp32 : test_with_fp32) { + // graph is partitioned into 3 parts. + for (int i = 0; i < static_cast(total_partition_count); ++i) { + + PathString output_file = GenerateFileNameWithIndex(ORT_TSTR("pipeline_partition_"), i, ORT_TSTR("_back.onnx")); + auto config = MakeBasicTrainingConfig(); + + if (i == static_cast(total_partition_count - 1)) { + config.loss_function_config = TrainingSession::TrainingConfiguration::LossFunctionConfiguration{}; + config.loss_function_config.value().loss_function_info = + LossFunctionInfo(OpDef("BertLoss", kOnnxDomain), + "total_loss", + {/*prediction_masked_lm*/ "prediction_scores", + /*prediction_next_sentence*/ "seq_relationship_score", + /*masked_lm_positions*/ "masked_lm_positions", + /*masked_lm_ids*/ "masked_lm_ids", + /*masked_lm_weights*/ "masked_lm_weights", + /*next_sentence_labels*/ "next_sentence_labels", + /*mlm_loss*/ "mlm_loss", + /*nsp_loss*/ "nsp_loss"}); + } + + // Add weight_names_to_not_train to avoid generating backward graph on those tensor + config.weight_names_to_not_train = { + "position_01", // Slice's dat input + "op_min_ends_expand_10", //op_min_ends_expand_10 + }; + + config.pipeline_config = pipe; + config.distributed_config.world_rank = i; + config.distributed_config.world_size = total_partition_count; + config.distributed_config.local_rank = i; + config.distributed_config.local_size = total_partition_count; + config.distributed_config.data_parallel_size = 1; + config.distributed_config.horizontal_parallel_size = 1; + config.distributed_config.pipeline_parallel_size = total_partition_count; + config.model_with_training_graph_path = output_file; + + if (!is_fp32) { + config.mixed_precision_config = mixed_precision_config; + } + + PathString backprop_model_file; + Status status = BuildBackPropGraph(model_path, config, backprop_model_file); + ASSERT_TRUE(status.IsOK()) << status << " (is_fp32 = " << is_fp32 << ", stage = " << i << ").\n"; + + // Skip the re-load for mixed-precision case. This model contains grad op that has function body, + // which takes a const tensor input. Const cast for input in function body won't be saved in the output + // model so reload will run into error. + // For the purpose of testing mixed-precision, BuildBackPropGraph above will be sufficient to verify the + // partition logic and validate the graph. + if (is_fp32) { + std::shared_ptr model; + // Ensure the partitioned model load. + status = Model::Load(backprop_model_file, model, nullptr, DefaultLoggingManager().DefaultLogger()); + ASSERT_TRUE(status.IsOK()) << status << " (is_fp32 = " << is_fp32 << ", stage = " << i << ").\n"; + + // verify the first stage contains word embedding as input and the last stage doesn't + auto model_proto = model->ToProto(); + const auto& graph_proto = model_proto.graph(); + + bool found_word_embedding = false; + for (auto& tensor : graph_proto.initializer()) { + if (tensor.name() == "bert.embeddings.word_embeddings.weight") { + found_word_embedding = true; + } + } + if (i == 0) { + ASSERT_TRUE(found_word_embedding) << " (is_fp32 = " << is_fp32 << ", stage = " << i << ").\n"; + } else { + ASSERT_FALSE(found_word_embedding) << " (is_fp32 = " << is_fp32 << ", stage = " << i << ").\n"; + } + } + } + } +} + +TEST(GradientGraphBuilderTest, PipelineOnlinePartition_MLP) { auto model_uri = ORIGINAL_MODEL_PATH; TrainingSession::TrainingConfiguration::PipelineConfiguration pipe{}; @@ -1137,14 +1244,10 @@ TEST(GradientGraphBuilderTest, PipelineOnlinePartition) { for(auto is_fp32 : test_with_fp32) { // graph is partitioned into 3 parts. for (int i = 0; i < 3; ++i) { -#ifdef _WIN32 - auto surfix = std::to_wstring(i); -#else - auto surfix = std::to_string(i); -#endif - PathString output_file = ORT_TSTR("pipeline_partition_") + surfix + ORT_TSTR("_back.onnx"); + PathString output_file = GenerateFileNameWithIndex(ORT_TSTR("pipeline_partition_"), i, ORT_TSTR("_back.onnx")); auto config = MakeBasicTrainingConfig(); + config.pipeline_config = pipe; config.distributed_config.world_rank = i; config.distributed_config.world_size = 3; @@ -1178,9 +1281,62 @@ TEST(GradientGraphBuilderTest, PipelineOnlinePartition) { } } +Status RunOnlinePartition(const std::vector& cut_list, + int pipeline_stage_size, + std::set status_check_stages = {}) { + auto model_uri = ORIGINAL_MODEL_PATH; + + TrainingSession::TrainingConfiguration::PipelineConfiguration pipe{}; + pipe.do_partition = true; + pipe.cut_list = cut_list; + + for (int i = 0; i < pipeline_stage_size; ++i) { + PathString output_file = GenerateFileNameWithIndex(ORT_TSTR("pipeline_partition_"), i, ORT_TSTR("_back.onnx")); + + auto config = MakeBasicTrainingConfig(); + config.pipeline_config = pipe; + + config.distributed_config.world_rank = i; + config.distributed_config.world_size = pipeline_stage_size; + config.distributed_config.local_rank = i; + config.distributed_config.local_size = pipeline_stage_size; + config.distributed_config.data_parallel_size = 1; + config.distributed_config.horizontal_parallel_size = 1; + config.distributed_config.pipeline_parallel_size = pipeline_stage_size; + config.model_with_training_graph_path = output_file; + + PathString backprop_model_file; + if (status_check_stages.count(i) > 0) { + auto status = BuildBackPropGraph(model_uri, config, backprop_model_file); + EXPECT_FALSE(status.IsOK()); + } else { + EXPECT_THROW(BuildBackPropGraph(model_uri, config, backprop_model_file), OnnxRuntimeException); + } + } + return Status::OK(); +} + +TEST(GradientGraphBuilderTest, PipelineOnlinePartition_Invalid_Input) { + using CutEdge = TrainingSession::TrainingConfiguration::CutEdge; + using CutInfo = TrainingSession::TrainingConfiguration::CutInfo; + + // Test with invalid cut edge + TrainingSession::TrainingConfiguration::CutInfo invalid_cut_edge = {TrainingSession::TrainingConfiguration::CutEdge("3")}; + ASSERT_STATUS_OK(RunOnlinePartition(std::vector{invalid_cut_edge}, 2 /* pipeline_stage_size */)); + + // Test mis-matched cut list with stage size + TrainingSession::TrainingConfiguration::CutInfo cut_edge = {TrainingSession::TrainingConfiguration::CutEdge("T3")}; + ASSERT_STATUS_OK(RunOnlinePartition(std::vector{cut_edge}, 3 /* pipeline_stage_size */)); + + // Test unordered cut_info list + CutInfo cut0 = {CutEdge("T3")}; + CutInfo cut1 = {CutEdge("T6")}; + ASSERT_STATUS_OK(RunOnlinePartition(std::vector{cut1, cut0}, 3 /* pipeline_stage_size */, {0, 2} /* status_check_stages */)); +} + // verify pipeline config can load and gradient graph can construct. TEST(GradientGraphBuilderTest, TrainingSession_PipelineTransform_base) { - PathString filename_base = ORT_TSTR("testdata/test_training_model_"); + PathString filename_base = ORT_TSTR("testdata/test_training_model_"); auto load_and_check_gradient_graph = [](int stageIdx, PathString& input_file, PathString& output_file) { auto config = MakeBasicTrainingConfig(); @@ -1286,13 +1442,9 @@ TEST(GradientGraphBuilderTest, TrainingSession_PipelineTransform_base) { }; for (int i = 0; i < 3; ++i) { -#ifdef _WIN32 - auto surfix = std::to_wstring(i); -#else - auto surfix = std::to_string(i); -#endif - PathString input_file = filename_base + surfix + ORT_TSTR(".onnx"); - PathString output_file = filename_base + surfix + ORT_TSTR("_back.onnx"); + PathString input_file = GenerateFileNameWithIndex(filename_base, i, ORT_TSTR(".onnx")); + PathString output_file = GenerateFileNameWithIndex(filename_base, i, ORT_TSTR("_back.onnx")); + load_and_check_gradient_graph(i, input_file, output_file); } } @@ -1358,12 +1510,7 @@ TEST(GradientGraphBuilderTest, TrainingSession_WithPipeline) { std::vector sub_model_files(num_subs); for (size_t sub_id = 0; sub_id < num_subs; ++sub_id) { -#ifdef _WIN32 - auto sub_id_str = std::to_wstring(sub_id); -#else - auto sub_id_str = std::to_string(sub_id); -#endif - sub_model_files[sub_id] = ORT_TSTR("sub_") + sub_id_str + ORT_TSTR(".onnx"); + sub_model_files[sub_id] = GenerateFileNameWithIndex(ORT_TSTR("sub_"), sub_id, ORT_TSTR(".onnx")); } PipelineSplitter splitter; @@ -1383,12 +1530,7 @@ TEST(GradientGraphBuilderTest, TrainingSession_WithPipeline) { for (size_t sub_id = 0; sub_id < num_subs; ++sub_id) { auto& sub_sess = subs[sub_id]; sub_sess.so.enable_profiling = true; -#ifdef _WIN32 - auto sub_id_str = std::to_wstring(sub_id); -#else - auto sub_id_str = std::to_string(sub_id); -#endif - sub_sess.so.profile_file_prefix = ORT_TSTR("pipeline") + sub_id_str; + sub_sess.so.profile_file_prefix = GenerateFileNameWithIndex(ORT_TSTR("pipeline"), static_cast(sub_id), ORT_TSTR("")); sub_sess.run_options.run_log_verbosity_level = sub_sess.so.session_log_verbosity_level; sub_sess.run_options.run_tag = sub_sess.so.session_logid;