From 633008b5ef1ef2fca7d16aef667b3d867a1b558d Mon Sep 17 00:00:00 2001 From: Xueyun Zhu <40807589+xzhu1900@users.noreply.github.com> Date: Tue, 26 May 2020 17:44:09 -0700 Subject: [PATCH] Add pipeline online partition logic for pipeline (#3996) * online partition * fix when multiple consumer nodes is in cut info * fix windows build * address feedback * adding test * feedback * address feedback * add parser for cut edge * windows build --- include/onnxruntime/core/graph/graph.h | 11 + include/onnxruntime/core/graph/graph_nodes.h | 17 +- onnxruntime/core/graph/graph.cc | 7 + .../core/graph/pipeline_transformer.cc | 442 ++++++++++++++++-- .../core/graph/pipeline_transformer.h | 9 +- .../core/session/training_session.cc | 21 +- .../core/session/training_session.h | 30 +- orttraining/orttraining/models/bert/main.cc | 57 +++ .../models/runner/training_runner.cc | 30 +- .../models/runner/training_runner.h | 15 +- .../test/graph/gradient_graph_builder_test.cc | 50 +- 11 files changed, 612 insertions(+), 77 deletions(-) diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 9661633eb9..717f5b21ce 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -525,6 +525,11 @@ class Graph { return graph_inputs_including_initializers_; } + /** Return true if "node_arg" is a input or an initializer. Otherwise, returns false. */ + bool IsInputsIncludingInitializers(const NodeArg* node_arg) const noexcept{ + return std::find(graph_inputs_including_initializers_.begin(), graph_inputs_including_initializers_.end(), node_arg) != graph_inputs_including_initializers_.end(); + } + /** Gets the Graph inputs that are initializers These are overridable initializers. 
This is a difference between graph_inputs_including_initializers_ and graph_inputs_excluding_initializers_ @@ -537,6 +542,10 @@ class Graph { @remarks Contains no nullptr values.*/ const std::vector& GetOutputs() const noexcept { return graph_outputs_; } + bool IsOutput(const NodeArg* node_arg) const noexcept{ + return std::find(graph_outputs_.begin(), graph_outputs_.end(), node_arg) != graph_outputs_.end(); + } + /** Returns a vector with the indexes of the outputs of the given Node that are also Graph outputs. */ std::vector GetNodeOutputsInGraphOutputs(const Node& node) const { int output_idx = 0; @@ -557,6 +566,8 @@ class Graph { @remarks Contains no nullptr values. */ const std::vector& GetValueInfo() const noexcept; + void AddValueInfo(const NodeArg* new_value_info); + /** Gets the Node with the specified node index. @returns Node instance if found. nullptr if node_index is invalid or node has been freed. */ diff --git a/include/onnxruntime/core/graph/graph_nodes.h b/include/onnxruntime/core/graph/graph_nodes.h index d565c3525d..2deba89b70 100644 --- a/include/onnxruntime/core/graph/graph_nodes.h +++ b/include/onnxruntime/core/graph/graph_nodes.h @@ -12,7 +12,7 @@ namespace onnxruntime { class Node; /** -Class to filter out null entries from either a vector of unique_ptr or a vector of [const] Node* and +Class to filter out null entries from either a vector of unique_ptr or a vector of [const] Node* and provide an iterator interface that returns [const] Node& for the valid entries. 
*/ template @@ -29,6 +29,7 @@ class ValidNodes { using ConstNodeIterator = NodeIterator; using MutableNodeIterator = NodeIterator; + using ConstReverseNodeIterator = NodeIterator; ConstNodeIterator cbegin() const noexcept { return {nodes_.cbegin(), nodes_.cend()}; @@ -46,6 +47,14 @@ class ValidNodes { return cend(); } + ConstReverseNodeIterator rbegin() const noexcept { + return {nodes_.crbegin(), nodes_.crend()}; + } + + ConstReverseNodeIterator rend() const noexcept { + return {nodes_.crend(), nodes_.crend()}; + } + MutableNodeIterator begin() noexcept { return {nodes_.begin(), nodes_.end()}; } @@ -56,10 +65,10 @@ class ValidNodes { bool empty() const noexcept { return nodes_.empty(); } - /** + /** @class NodeIterator Iterator to provide const and non-const access to valid Node instances in a Graph. - @remarks Skips invalid nodes. + @remarks Skips invalid nodes. */ template class NodeIterator { @@ -130,7 +139,7 @@ class ValidNodes { }; /** -Class that provides iteration over all valid nodes in the Graph. +Class that provides iteration over all valid nodes in the Graph. 
*/ class GraphNodes : public ValidNodes>> { public: diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 16b6049189..eac2e7e9ec 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -2337,6 +2337,13 @@ const std::vector& Graph::GetValueInfo() const noexcept { return value_info_; } +void Graph::AddValueInfo(const NodeArg* new_value_info){ + for(const auto* info : value_info_){ + ORT_ENFORCE(info->Name() != new_value_info->Name(), "Error: trying to add an existing value info."); + } + value_info_.push_back(new_value_info); +} + std::vector Graph::CreateNodeArgs(const google::protobuf::RepeatedPtrField& names, const ArgNameToTypeMap& name_to_type_map) { const auto name_to_type_map_end = name_to_type_map.end(); diff --git a/orttraining/orttraining/core/graph/pipeline_transformer.cc b/orttraining/orttraining/core/graph/pipeline_transformer.cc index 85b89a2e44..4ca48178e9 100644 --- a/orttraining/orttraining/core/graph/pipeline_transformer.cc +++ b/orttraining/orttraining/core/graph/pipeline_transformer.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. 
#include "orttraining/core/graph/pipeline_transformer.h" +#include using namespace onnxruntime::common; @@ -22,21 +23,22 @@ bool IsBackward(Node& node) { return (node.Description() == "Backward pass"); } -NodeArg& CreateInt64NodeArg(Graph& graph, const std::string& name) { +NodeArg& CreateTypedNodeArg(Graph& graph, onnx::TensorProto_DataType type, const std::string& name) { ONNX_NAMESPACE::TypeProto type_proto; - type_proto.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + type_proto.mutable_tensor_type()->set_elem_type(type); auto actual_name = graph.GenerateNodeArgName(name); auto& node_arg = graph.GetOrCreateNodeArg(actual_name, &type_proto); return node_arg; } -void AddInputEvent(Graph& graph, - const std::string& event_name, - std::vector& input_args, - std::vector& new_input_names) { - auto& event_id = CreateInt64NodeArg(graph, event_name); - new_input_names.push_back(event_id.Name()); - input_args.push_back(&event_id); +void AddNewNodeArg(Graph& graph, + const std::string& op_name, + onnx::TensorProto_DataType type, + std::vector& new_node_args, + std::vector& new_names) { + auto& new_node_arg = CreateTypedNodeArg(graph, type, op_name); + new_names.push_back(new_node_arg.Name()); + new_node_args.push_back(&new_node_arg); } // gradient graph can contain some dangling leaf nodes. 
Add them all to WaitEvent @@ -112,7 +114,7 @@ Node& CreateBottleneckNode(Graph& graph, nullptr /* assume all bottleneck node have no attributes */, kMSDomain); } - + Node* AddBackwardRecord(Graph& graph, Node* backward_send, std::vector& new_input_names, @@ -120,7 +122,8 @@ Node* AddBackwardRecord(Graph& graph, std::string &event_id_tensor_name, std::string &output_tensor_name) { std::vector input_args; - AddInputEvent(graph, "backward_recorded_event_id", input_args, new_input_names); + AddNewNodeArg(graph, "backward_recorded_event_id", ONNX_NAMESPACE::TensorProto_DataType_INT64, + input_args, new_input_names); std::vector output_args{}; if (backward_send) { @@ -175,10 +178,11 @@ Node* AddForwardWait(Graph& graph, std::vector input_args; std::vector output_args; - AddInputEvent(graph, "forward_waited_event_id", input_args, new_input_names); + AddNewNodeArg(graph, "forward_waited_event_id", ONNX_NAMESPACE::TensorProto_DataType_INT64, + input_args, new_input_names); const std::vector& graph_inputs = graph.GetInputsIncludingInitializers(); - if (graph_inputs.size() == 0){ + if (graph_inputs.size() == 0) { ORT_THROW("Graph ", graph.Name(), " doesn't have any inputs."); } @@ -237,7 +241,8 @@ Status AddOrSkipForwardRecordBackwardWait(Graph& graph, { std::vector input_args; std::vector output_args; - AddInputEvent(graph, "forward_recorded_event_id", input_args, new_input_names); + AddNewNodeArg(graph, "forward_recorded_event_id", ONNX_NAMESPACE::TensorProto_DataType_INT64, + input_args, new_input_names); // Add send forward op's output as record op's input and output for (auto& output : forward_send->MutableOutputDefs()) { @@ -258,7 +263,8 @@ Status AddOrSkipForwardRecordBackwardWait(Graph& graph, { std::vector input_args; std::vector output_args; - AddInputEvent(graph, "backward_waited_event_id", input_args, new_input_names); + AddNewNodeArg(graph, "backward_waited_event_id", ONNX_NAMESPACE::TensorProto_DataType_INT64, + input_args, new_input_names); 
input_args.insert(std::end(input_args), std::begin(record_node->MutableOutputDefs()), @@ -287,7 +293,6 @@ void ReplaceNodeArgs(std::vector& nodes, ORT_ENFORCE(node_args.size() == new_node_args.size()); for (size_t i = 0; i < node_args.size(); ++i) { // Iteration for node_args[i] and new_node_args[i]. - ORT_ENFORCE(node_args[i]->Name() != new_node_args[i]->Name()); ORT_ENFORCE(node_args[i]->Type() == new_node_args[i]->Type()); @@ -331,10 +336,10 @@ std::string AddEventBeforeNode( std::vector nodes = {node}; // Replace node_args[i] with new_node_args[i] in nodes. - ReplaceNodeArgs(nodes, node_args, new_node_args); + ReplaceNodeArgs(nodes, node_args, new_node_args); // Create node_arg for event ID. - auto event_node_arg = &CreateInt64NodeArg(graph, event_id_name); + auto event_node_arg = &CreateTypedNodeArg(graph, ONNX_NAMESPACE::TensorProto_DataType_INT64, event_id_name); // Create node which produces new_node_args from event ID and node_args. CreateBottleneckNode(graph, @@ -372,11 +377,11 @@ std::string AddEventAfterNode( std::vector consumer_nodes = graph.GetMutableConsumerNodes( node_args[i]->Name()); // Replace node_args[i] with new_node_args[i] in nodes. - ReplaceNodeArgs(consumer_nodes, {node_args[i]}, {new_node_args[i]}); + ReplaceNodeArgs(consumer_nodes, {node_args[i]}, {new_node_args[i]}); } // Create node_arg for event ID. - auto event_node_arg = &CreateInt64NodeArg(graph, event_id_name); + auto event_node_arg = &CreateTypedNodeArg(graph, ONNX_NAMESPACE::TensorProto_DataType_INT64, event_id_name); // Create node which produces new_node_args from event ID and node_args. 
CreateBottleneckNode(graph, @@ -395,10 +400,10 @@ Status AddForwardWaitAfterRecv( Node* comm_node, std::vector& new_input_names, std::string& event_name) { - event_name = AddEventAfterNode( + event_name = AddEventAfterNode( graph, comm_node, "WaitEvent", "forward_wait_after_recv", - "forward_wait_after_recv_event_id"); + "forward_wait_after_recv_event_id"); if (event_name.empty()) { return Status::OK(); } else { @@ -458,6 +463,37 @@ Status AddBackwardRecordBeforeSend( } } +Status SetInputsOutputsAndResolve(Graph& graph, + const std::vector& new_input_names, + const std::vector& new_output_names) { + auto fill_node_args = [&](const Graph& graph, + const std::vector& existed_node_args, + const std::vector& new_node_arg_names, + std::vector& merged_node_args) { + merged_node_args.insert(merged_node_args.end(), existed_node_args.begin(), existed_node_args.end()); + for (auto& name : new_node_arg_names) { + merged_node_args.push_back(graph.GetNodeArg(name)); + } + }; + + const std::vector& graph_inputs = graph.GetInputsIncludingInitializers(); + std::vector inputs_args_sets; + inputs_args_sets.reserve(graph_inputs.size() + new_input_names.size()); + fill_node_args(graph, graph_inputs, new_input_names, inputs_args_sets); + + const std::vector& graph_outputs = graph.GetOutputs(); + std::vector outputs_args_sets; + outputs_args_sets.reserve(graph_outputs.size() + new_output_names.size()); + fill_node_args(graph, graph_outputs, new_output_names, outputs_args_sets); + + graph.SetInputs(inputs_args_sets); + graph.SetOutputs(outputs_args_sets); + graph.SetGraphResolveNeeded(); + graph.SetGraphProtoSyncNeeded(); + + return graph.Resolve(); +} + // This function inserts WaitEvent's and RecordEvent's to the input graph for // controlling synchronization between (batch, pipeline stage)-pairs. 
// @@ -570,7 +606,7 @@ Status TransformGraphForPipeline( backward_waited_event_name, forward_record_output_name, backward_wait_output_name)); - + // Different stages have different patterns of Send & Recv. // For different patterns, we add different WaitEvent and Record. // @@ -691,33 +727,357 @@ Status TransformGraphForPipeline( )); } - auto fill_node_args = [&](const Graph& graph, - const std::vector& existed_node_args, - std::vector& new_node_arg_names, - std::vector& merged_node_args) { - merged_node_args.insert(merged_node_args.end(), existed_node_args.begin(), existed_node_args.end()); - for (auto& name : new_node_arg_names) { - merged_node_args.push_back(graph.GetNodeArg(name)); + ORT_RETURN_IF_ERROR(SetInputsOutputsAndResolve(graph, new_input_names, new_output_names)); + return Status::OK(); +} + +// This function is used when you want to create a scalar constant in a graph. +// It may create a NodeArg so that other Nodes can reference its value. +// It also creates an initializer to store its value.
+template +void AddNewScalarNodeArgAndInitializer(Graph& graph, + const std::string& op_name, + onnx::TensorProto_DataType type, + T data, + std::vector& new_node_args, + std::vector& new_names) { + AddNewNodeArg(graph, op_name, type, new_node_args, new_names); + + ONNX_NAMESPACE::TensorProto proto_data; + proto_data.set_name(new_names.back()); + proto_data.set_data_type(type); + + switch (type) { + case ONNX_NAMESPACE::TensorProto_DataType_BOOL: + proto_data.add_int32_data(static_cast(data)); + break; + case ONNX_NAMESPACE::TensorProto_DataType_INT64: + proto_data.add_int64_data(static_cast(data)); + break; + default: + ORT_THROW("pipeline partition unsupported 'type' value: ", type); + } + graph.AddInitializedTensor(proto_data); +} + +// split the graph into disconnected subgraph based on provided CutInfo +common::Status SplitGraph(Graph& graph, + std::vector split_edge_groups, + std::vector& send_nodes, + std::vector& recv_nodes) { + std::vector new_input_names; + std::vector new_output_names; + + // updated_node_args keeps track of the mapping between the original node_arg and its corresponding updated + // node_arg after send and recv node is added. As multiple partitions can happen, and a single node_arg + // can belong to different partition, updated_node_args always keeps track of the latest updated node_arg. + // Below is one example of how this works using updated_node_args: + // there are three edges in graph, specified as nodeA->nodeB, nodeA->nodeC, and nodeA->nodeD. + // those edges all share the same node_arg. + // but nodeA, nodeB belong to partition0, nodeC belongs to partition1, and nodeD belongs to partition2. + // This means we need to cut edge nodeA->nodeC for the first partition and nodeA->nodeD for the second partition. + // + // During the first cut, we identify the edge nodeA->nodeC, for this edge, based on the original node_arg, + // we create a new node_arg, called updated_node_arg.
The inserted send node will take the original node_arg + // as input and the inserted recv node will take the updated_node_arg as the output. + // And we update updated_node_args with updated_node_args[original_node_arg] = updated_node_arg + // + // Now during the second cut, we need to cut the edge nodeA->nodeD. Note that as the cuts are performed sequentially, + // the second cut is performed based on the graph modified after the first cut. This means, the input node_arg for + // nodeD shouldn't come from nodeA anymore, as nodeA now resides in partition0, which is a disconnected partition. + // Instead, the input node_arg of nodeD should come from the updated version: updated_node_arg from partition1. + // By using the updated_node_args map, we can retrieve updated_node_arg from original_node_arg, and use that as the + // newly inserted send's input. Also, to keep this going for any following cut, we create an updated_node_arg_v2, + // and update updated_node_args with updated_node_args[original_node_arg] = updated_node_arg_v2 + std::map updated_node_args; + for (size_t index = 0; index < split_edge_groups.size(); ++index) { + // each entry in split_edge_groups represents a partition cut. Each cut can contain the split of + // several edges. + auto& edgeIds = split_edge_groups[index]; + + // for each cut, record the inserted input/output args.
+ std::vector send_input_args; + std::vector send_output_args; + std::vector recv_input_args; + std::vector recv_output_args; + + auto cut_index_str = std::to_string(index); + // add input node_arg and initializer for send/recv + AddNewScalarNodeArgAndInitializer(graph, + "send_input_signal" + cut_index_str, + ONNX_NAMESPACE::TensorProto_DataType_BOOL, + true, /* initializer data */ + send_input_args, + new_input_names); + AddNewScalarNodeArgAndInitializer(graph, + "recv_input_signal" + cut_index_str, + ONNX_NAMESPACE::TensorProto_DataType_BOOL, + true, /* initializer data */ + recv_input_args, + new_input_names); + + AddNewScalarNodeArgAndInitializer(graph, + "send_dst_rank" + cut_index_str, + ONNX_NAMESPACE::TensorProto_DataType_INT64, + index + 1, /* initializer data */ + send_input_args, + new_input_names); + AddNewScalarNodeArgAndInitializer(graph, + "recv_src_rank" + cut_index_str, + ONNX_NAMESPACE::TensorProto_DataType_INT64, + index, /* initializer data */ + recv_input_args, + new_input_names); + // add output node_arg for send/recv + AddNewNodeArg(graph, "send_output_signal" + cut_index_str, ONNX_NAMESPACE::TensorProto_DataType_BOOL, + send_output_args, new_output_names); + AddNewNodeArg(graph, "receive_output_signal" + cut_index_str, ONNX_NAMESPACE::TensorProto_DataType_BOOL, + recv_output_args, new_output_names); + + // add attribute data for send/recv + ONNX_NAMESPACE::AttributeProto tag; + tag.set_name("tag"); + tag.set_type(ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_INT); + // currently hard-coded all tag to be 0. May need to change when multiple GPU stream is used. 
+ tag.set_i(static_cast(0)); + + ONNX_NAMESPACE::AttributeProto element_types; + element_types.set_name("element_types"); + element_types.set_type(ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_INTS); + + // for each edge in this group, perform edge cut + for (auto& id : edgeIds) { + // find node whose output contains id.node_arg_name + auto producer_node = graph.GetMutableProducerNode(id.node_arg_name); + if (!producer_node) { + ORT_THROW("Cannot find producer node of node_arg with name: ", id.node_arg_name, ". Wrong cutting infomation."); + } + + // once we find out the producer node for id.node_arg_name, find which output index that leads + // to id.node_arg_name + int upstream_nodes_output_index{-1}; + producer_node->ForEachWithIndex( + producer_node->OutputDefs(), + [&](const NodeArg& def, size_t index) { + if (def.Name() == id.node_arg_name) { + upstream_nodes_output_index = static_cast(index); + } + return Status::OK(); + }); + + if (upstream_nodes_output_index < 0) { + ORT_THROW("Node with name: ", producer_node->Name(), + " doesn't have an output node_arg with name ", id.node_arg_name); + } + + size_t idx = static_cast(upstream_nodes_output_index); + + // original node_arg pointer from the origin graph. This serves as the key in the + // updated_node_arg map and any reference for original node_arg name + auto* original_node_arg = producer_node->MutableOutputDefs()[idx]; + + // updated node_arg pointer from previous partition. This is the new arg_node the + // current inserted send node will take as input node_arg. 
+ auto updated_node_arg = producer_node->MutableOutputDefs()[idx]; + auto exiting_updated_node_arg = updated_node_args.find(original_node_arg); + if (exiting_updated_node_arg != updated_node_args.end()) { + updated_node_arg = exiting_updated_node_arg->second; + } + + send_input_args.push_back(updated_node_arg); + + auto dtype = original_node_arg->TypeAsProto()->tensor_type().elem_type(); + switch (dtype) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + element_types.add_ints(static_cast(1)); + break; + default: + // Assume all tensors are of type float. + // TODO: update if graph supports other data type. + ORT_THROW("pipeline partition unsupported 'type' value: ", dtype); + } + + auto& new_receive_output = CreateNodeArg(graph, updated_node_arg); + const auto old_shape = *(updated_node_arg->Shape()); + new_receive_output.SetShape(old_shape); + recv_output_args.push_back(&new_receive_output); + + // add value info for this newly added receive_output, for shape propagation + // when training this partition. 
+ graph.AddValueInfo(&new_receive_output); + + // update updated_node_args with the newly created node_arg + updated_node_args[original_node_arg] = &new_receive_output; + + // deal with shape inference for newly added edge + auto& output_edge_name = original_node_arg->Name(); + + // deal with updating the consumer's input node_args + std::vector consumer_nodes; + if (id.consumer_nodes.has_value()) { + for(auto& consumer_node_id : id.consumer_nodes.value()){ + consumer_nodes.push_back(graph.GetMutableProducerNode(consumer_node_id)); + } + } else { + consumer_nodes = graph.GetMutableConsumerNodes(output_edge_name); + } + + for (auto consumer_node : consumer_nodes) { + for (auto& input : consumer_node->MutableInputDefs()) { + if (input->Name() == output_edge_name) { + input = &new_receive_output; + break; + } + } + } } - }; + const int num_attributes = 2; // two attributes: tag and element_types + NodeAttributes attributes; + attributes.reserve(num_attributes); + attributes[tag.name()] = tag; + attributes[element_types.name()] = element_types; - const std::vector& graph_inputs = graph.GetInputsIncludingInitializers(); - std::vector inputs_args_sets; - inputs_args_sets.reserve(graph_inputs.size() + new_input_names.size()); - fill_node_args(graph, graph_inputs, new_input_names, inputs_args_sets); + auto& send_node = graph.AddNode(graph.GenerateNodeName("Send"), + "Send", + "", + send_input_args, + send_output_args, /* output */ + &attributes, /* attribute */ + kMSDomain); - const std::vector& graph_outputs = graph.GetOutputs(); - std::vector outputs_args_sets; - outputs_args_sets.reserve(graph_outputs.size() + new_output_names.size()); - fill_node_args(graph, graph_outputs, new_output_names, outputs_args_sets); + send_nodes.push_back(&send_node); - graph.SetInputs(inputs_args_sets); - graph.SetOutputs(outputs_args_sets); + auto& recv_node = graph.AddNode(graph.GenerateNodeName("Recv"), + "Recv", + "", + recv_input_args, + recv_output_args, /* output */ + &attributes, /* 
attribute */ + kMSDomain); + recv_nodes.push_back(&recv_node); + } + + ORT_RETURN_IF_ERROR(SetInputsOutputsAndResolve(graph, new_input_names, new_output_names)); + return Status::OK(); +} + +Status FindAllConnectedNodes(Graph& graph, + const Node* node, + std::vector& connected_nodes, + std::set& connected_inputs, + std::set& connected_outputs) { + ORT_THROW_IF_ERROR(node->ForEachWithIndex( + node->InputDefs(), + [&](const NodeArg& node_arg, size_t /*index*/) { + if (graph.IsInputsIncludingInitializers(&node_arg)) { + connected_inputs.insert(&node_arg); + } else { + const Node* producer_node = graph.GetProducerNode(node_arg.Name()); + if (producer_node == nullptr) { + // got nullptr as producer node. This could be because the input is a constant op which will be optimized + // away. Print out this information and continue. + LOGS_DEFAULT(WARNING) << "Cannot find producer node for node_arg: " << node_arg.Name() << ". Skipping this node."; + } else { + connected_nodes.push_back(producer_node); + } + } + return Status::OK(); + })); + + ORT_THROW_IF_ERROR(node->ForEachWithIndex( + node->OutputDefs(), + [&](const NodeArg& node_arg, size_t /*index*/) { + if (!graph.IsOutput(&node_arg)) { + std::vector consumer_nodes = graph.GetConsumerNodes(node_arg.Name()); + connected_nodes.insert(std::end(connected_nodes), consumer_nodes.begin(), consumer_nodes.end()); + + } else { + connected_outputs.insert(&node_arg); + } + return Status::OK(); + })); + return Status::OK(); +} + +// traverse the graph from start_node to get the set of nodes contains in this disconnected subgraph +common::Status GenerateSubgraph(Graph& graph, const Node* start_node) { + std::queue node_queue; + node_queue.push(start_node); + + std::set visited_nodes; + std::set visited_inputs; + std::set visited_outputs; + + // BFS graph traverse + while (!node_queue.empty()) { + auto node = node_queue.front(); + node_queue.pop(); + if (visited_nodes.count(node) == 0) { + visited_nodes.insert(node); + std::vector 
connected_nodes; + ORT_THROW_IF_ERROR(FindAllConnectedNodes(graph, node, connected_nodes, visited_inputs, visited_outputs)); + + for (auto n : connected_nodes) { + ORT_ENFORCE(n!=nullptr, "Found nullptr in searching for connected nodes"); + node_queue.push(n); + } + } + } + std::set visited_node_index; + for (auto n : visited_nodes) { + visited_node_index.insert(n->Index()); + } + + GraphViewer graph_viewer(graph); + const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); + + // reverse iterate the nodes in topological order, and delete those not visited + for (auto it = node_topology_list.rbegin(); it != node_topology_list.rend(); it++){ + if (visited_node_index.count(*it)==0){ + graph.RemoveNode(*it); + } + } + + // update the graph with only visited inputs and outputs + graph.SetInputs({visited_inputs.begin(), visited_inputs.end()}); + graph.SetOutputs({visited_outputs.begin(), visited_outputs.end()}); graph.SetGraphResolveNeeded(); graph.SetGraphProtoSyncNeeded(); return graph.Resolve(); } +Status ApplyPipelinePartitionToMainGraph( + Graph& graph, + const std::vector& cut_info, + size_t pipeline_stage_id, + size_t num_pipeline_stage) { + size_t split_count = cut_info.size(); + + if (num_pipeline_stage != split_count + 1) { + ORT_THROW("Wrong pipeline partition cutting info. Total pipeline stage number is ", + num_pipeline_stage, + ", cut info length is: ", + split_count); + } + + std::vector send_nodes, recv_nodes; + send_nodes.reserve(split_count); + recv_nodes.reserve(split_count); + ORT_RETURN_IF_ERROR(SplitGraph(graph, cut_info, send_nodes, recv_nodes)); + + if (send_nodes.size() != split_count || recv_nodes.size() != split_count) { + ORT_THROW("Split error: not all cut has Send and Recv inserted. 
Send node count: ", + send_nodes.size(), ", Recv node count: ", recv_nodes.size(), ", split count: ", split_count); + } + + if (pipeline_stage_id < split_count) { + ORT_RETURN_IF_ERROR(GenerateSubgraph(graph, send_nodes[pipeline_stage_id])); + } else { + ORT_RETURN_IF_ERROR(GenerateSubgraph(graph, recv_nodes.back())); + } + return Status::OK(); +} + } // namespace training } // namespace onnxruntime diff --git a/orttraining/orttraining/core/graph/pipeline_transformer.h b/orttraining/orttraining/core/graph/pipeline_transformer.h index f2bc3d49f8..bbc424439b 100644 --- a/orttraining/orttraining/core/graph/pipeline_transformer.h +++ b/orttraining/orttraining/core/graph/pipeline_transformer.h @@ -4,12 +4,14 @@ #pragma once #include "core/graph/graph.h" +#include "orttraining/core/session/training_session.h" namespace onnxruntime { namespace training { void GetPipelineSendOutput(const Graph& graph, std::string& loss_name); -common::Status TransformGraphForPipeline( + +Status TransformGraphForPipeline( Graph& graph, std::string& forward_waited_event_name, std::string& forward_recorded_event_name, @@ -24,5 +26,10 @@ common::Status TransformGraphForPipeline( std::string& backward_waited_event_after_recv_name, std::string& backward_recorded_event_before_send_name); +Status ApplyPipelinePartitionToMainGraph( + Graph& graph, + const std::vector& cut_info, + size_t pipeline_stage_id, + size_t num_pipeline_stage); } // namespace training } // namespace onnxruntime diff --git a/orttraining/orttraining/core/session/training_session.cc b/orttraining/orttraining/core/session/training_session.cc index cc99d2c523..67201f6a57 100644 --- a/orttraining/orttraining/core/session/training_session.cc +++ b/orttraining/orttraining/core/session/training_session.cc @@ -141,6 +141,15 @@ Status TrainingSession::ConfigureForTraining( excluded_initializers.erase(weight_name_to_not_train); } + if (config.pipeline_config.has_value() && config.pipeline_config.value().do_partition) { + // Apply online 
pipeline partition to graph obj. This needs to be done first before any graph + // transformation which may alter node_arg and invalidate cut_list info from the original graph. + ORT_RETURN_IF_ERROR(ApplyPipelinePartitionToMainGraph(model_->MainGraph(), + config.pipeline_config.value().cut_list, + config.distributed_config.world_rank, + config.distributed_config.world_size)); + } + ORT_RETURN_IF_ERROR(ApplyTransformationsToMainGraph(excluded_initializers)); is_mixed_precision_enabled_ = config.mixed_precision_config.has_value(); @@ -233,7 +242,7 @@ Status TrainingSession::ConfigureForTraining( } pipeline_result.fetch_names.push_back(name); } - pipeline_result.pipeline_stage_id = config.distributed_config.world_rank / + pipeline_result.pipeline_stage_id = config.distributed_config.world_rank / (config.distributed_config.data_parallel_size * config.distributed_config.horizontal_parallel_size); config_result.pipeline_config_result = pipeline_result; } @@ -307,13 +316,19 @@ Status TrainingSession::ConfigureForTraining( tensorboard_config.histogram_node_names, tensorboard_config.norm_node_names, tensorboard_config.dump_convergence_metrics)); } - + // add GIST encoding if (config.gist_config.has_value()) { ORT_RETURN_IF_ERROR(AddGistEncoding()); } - if (IsRootNode(config) && config.model_with_training_graph_path.has_value()) { + // If the current node is in rank0 or if the current session is running pipeline (in which case different rank would + // store different model partition), and if model_with_training_graph_path is specified, save the model. + // Note: in the pipeline case, different ranks may reside in the same node. This could lead to a potential write + // conflict. It is the user's responsibility to make sure different rank is passed in with different 
+ // model_with_training_graph_path value. + if ((IsRootNode(config) || config.pipeline_config.has_value()) + && config.model_with_training_graph_path.has_value()) { ORT_IGNORE_RETURN_VALUE(Save( config.model_with_training_graph_path.value(), SaveOption::NO_RELOAD)); } } diff --git a/orttraining/orttraining/core/session/training_session.h b/orttraining/orttraining/core/session/training_session.h index 0d54ab9842..4dfdd5f3d8 100644 --- a/orttraining/orttraining/core/session/training_session.h +++ b/orttraining/orttraining/core/session/training_session.h @@ -135,6 +135,28 @@ class TrainingSession : public InferenceSession { // If not provided, no optimizer is added. optional optimizer_config{}; + // struct to describe a specific edge. An edge is not the same as a node_arg. Edge represents a connection between two operators. + // For example, an operator A's output tensor T is connected to another operator B's input, then this constructs + // an edge from A to B. If A's output tensor T has multiple consumers, i.e. it's fed into multiple operators' inputs, + // there would be multiple edges, each from A, to a consumer operator. + // CutEdge information is used in pipeline online partition tool to identify which edge to cut to make the + // corresponding partition. + struct CutEdge { + std::string node_arg_name; + optional> consumer_nodes; + + // If the edge is unique, i.e. it only has one consumer node, or all the edges + // with the same node_arg_name need to be cut, specifying the node_arg_name + // suffices. + CutEdge(std::string edge) : node_arg_name(edge){}; + // If the edges with the same node_arg_name belong to different cuts, i.e. some of its + // consumer nodes belong to one partition, and some belong to another, specify + // the consumer node names which you want to perform the cut on. + CutEdge(std::string edge, std::vector nodes) : node_arg_name(edge), consumer_nodes(nodes){}; + }; + // CutInfo is a group of CutEdges that describes a specific cut that is composed of splitting those edges.
+ typedef std::vector CutInfo; + struct PipelineConfiguration { // If model partition happens outside ORT, this flag should be false. // Otherwise, use true to trigger ORT's pipeline partition. @@ -142,7 +164,9 @@ class TrainingSession : public InferenceSession { // Tensors to fetch as specified by the user. // Each pipeline stage should pick up some strings from this field.. std::vector fetch_names; - // [TODO] Add cut information. + // cut_list contains the list of CutInfo to make the graph partitions. + // cut_list[i] contains the CutInfo to make the partition between stage i and stage i+1 + std::vector cut_list; }; // If pipeline is enabled, this field's has_value() returns true. @@ -329,7 +353,7 @@ class TrainingSession : public InferenceSession { // Insert operators for running pipeline and return event tensor names. // For an intermediate pipeline stage, its original computation is // - // Recv --> Forward --> Send --> + // Recv --> Forward --> Send --> // Recv --> Backward --> Send // // After this function, the resulted computation is @@ -340,7 +364,7 @@ class TrainingSession : public InferenceSession { // As you can see, some event operators are inserted. For each event operator, its dependent // event tensor name is written to an input references, for example, "forward_waited_event_name". // - // This function assumes that + // This function assumes that // 1. Only one Recv and only one Send present in forward pass. // 2. Only one Recv and only one Send present in backward pass. // 3. Backward operators' descriptions are all "Backward pass". 
This assumption is used to diff --git a/orttraining/orttraining/models/bert/main.cc b/orttraining/orttraining/models/bert/main.cc index 1e5a7b3e08..cd719efee9 100644 --- a/orttraining/orttraining/models/bert/main.cc +++ b/orttraining/orttraining/models/bert/main.cc @@ -155,6 +155,11 @@ Status ParseArguments(int argc, char* argv[], BertParameters& params, OrtParamet ("horizontal_parallel_size", "Horizontal model parallel group size.", cxxopts::value()->default_value("1")) ("pipeline_parallel_size", "Number of pipeline stages.", cxxopts::value()->default_value("1")) ("pipeline_stage_paths", "Specify the forward ONNX files for pipeline evaluation.", cxxopts::value>()->default_value("")) + ("cut_group_info", "Specify the cutting info for graph partition (pipeline only). An example of a cut_group_info of " + "size two is: 1393:407-1463/1585/1707,2369:407-2439/2561/2683. Here, the cut info is split by ',', with the first " + "cut_info equal to 1393:407-1463/1585/1707, and second cut_info equal to 2369:407-2439/2561/2683. Each CutEdge is " + "seperated by ':'. If consumer nodes need to be specified, specify them after producer node with a '-' delimiter and " + "separate each consumer node with a '/'. ", cxxopts::value>()->default_value("")) ("enable_grad_norm_clip", "Specify whether to enable gradient clipping for optimizers.", cxxopts::value()->default_value("true")); options @@ -379,6 +384,58 @@ Status ParseArguments(int argc, char* argv[], BertParameters& params, OrtParamet // Backward pass and optimizer nodes are implicitly generated by ORT. params.pipeline_stage_paths = flags["pipeline_stage_paths"].as>(); + // If user doesn't provide partitioned model files, a cut list should be provided for ORT to do partition + // online. If the pipeline contains n stages, the cut list should be of length (n-1), in order to cut the + // graph into n partitions. 
+ if (params.pipeline_parallel_size > 1 && params.pipeline_stage_paths.empty()) { + auto cut_info_groups = flags["cut_group_info"].as>(); + + ORT_RETURN_IF_NOT(static_cast(cut_info_groups.size() + 1) == params.pipeline_parallel_size, + "cut_info length plus one must match pipeline parallel size"); + + auto process_with_delimiter = [](std::string& input_str, const std::string& delimiter) { + std::vector result; + size_t pos = 0; + std::string token; + while ((pos = input_str.find(delimiter)) != std::string::npos) { + token = input_str.substr(0, pos); + result.emplace_back(token); + input_str.erase(0, pos + delimiter.length()); + } + // push the last split of substring into result. + result.emplace_back(input_str); + return result; + }; + + auto process_cut_info = [&](std::string& cut_info_string) { + TrainingSession::TrainingConfiguration::CutInfo cut_info; + const std::string edge_delimiter = ":"; + const std::string consumer_delimiter = "/"; + const std::string producer_consumer_delimiter = "-"; + + auto cut_edges = process_with_delimiter(cut_info_string, edge_delimiter); + for (auto& cut_edge : cut_edges) { + auto process_edge = process_with_delimiter(cut_edge, producer_consumer_delimiter); + if (process_edge.size() == 1) { + TrainingSession::TrainingConfiguration::CutEdge edge{process_edge[0]}; + cut_info.emplace_back(edge); + } else { + ORT_ENFORCE(process_edge.size() == 2); + auto consumer_list = process_with_delimiter(process_edge[1], consumer_delimiter); + + TrainingSession::TrainingConfiguration::CutEdge edge{process_edge[0], consumer_list}; + cut_info.emplace_back(edge); + } + } + return cut_info; + }; + + for (auto& cut_info : cut_info_groups) { + TrainingSession::TrainingConfiguration::CutInfo cut = process_cut_info(cut_info); + params.pipeline_partition_cut_list.emplace_back(cut); + } + } + int64_t seed = flags["seed"].as(); if (params.horizontal_parallel_size > 1 && seed <= 0) { seed = 8211; // Megatron needs a random seed. 
diff --git a/orttraining/orttraining/models/runner/training_runner.cc b/orttraining/orttraining/models/runner/training_runner.cc index 805e9d53bd..1870237bf3 100644 --- a/orttraining/orttraining/models/runner/training_runner.cc +++ b/orttraining/orttraining/models/runner/training_runner.cc @@ -158,6 +158,7 @@ Status TrainingRunner::Initialize() { // the session already loads a pipeline stage. pipe.do_partition = params_.pipeline_stage_paths.empty() ? true : false; pipe.fetch_names = params_.fetch_names; + pipe.cut_list = params_.pipeline_partition_cut_list; // Do not assign value to config.pipeline_config if pipeline is not used. config.pipeline_config = pipe; } @@ -321,7 +322,7 @@ Status TrainingRunner::Initialize() { return Status::OK(); } -Status TrainingRunner::Run(IDataLoader* training_data_loader, IDataLoader* test_data_loader, +Status TrainingRunner::Run(IDataLoader* training_data_loader, IDataLoader* test_data_loader, const MapStringToString& mapped_dimensions) { if (params_.mpi_context.world_rank == 0 && !params_.model_actual_running_graph_path.empty()) { session_.Save(params_.model_actual_running_graph_path, TrainingSession::SaveOption::NO_RELOAD); @@ -528,7 +529,7 @@ Status TrainingRunner::PrepareFetchNamesAndFetches(const SessionMode mode, const auto& allowed_fetch_names = pipeline_context_.fetch_names; if (mode == ModelUpdateStep) { - // Set up tensor to be fetched when doing model update. + // Set up tensor to be fetched when doing model update. if (params_.pipeline_parallel_size > 1) { // If pipeline is used, we need to filter out fetches which are not in this pipeline stage. @@ -557,7 +558,7 @@ Status TrainingRunner::PrepareFetchNamesAndFetches(const SessionMode mode, } } } else if (mode == GradientAccumulateStep) { - // Set up tensor to be fetched when doing gradient accumulation. + // Set up tensor to be fetched when doing gradient accumulation. 
if (params_.gradient_accumulation_steps > 1) { auto it = opt_graph_outputs_.find(OptimizerOutputKey::GradientAccumulation); @@ -583,7 +584,7 @@ Status TrainingRunner::PrepareFetchNamesAndFetches(const SessionMode mode, } } else if (mode == EvaluateStep) { // Set up tensor to be fetched when doing model evaluation. - // Ideally, this path should not fetch optimizer and gradient accumulation. + // Ideally, this path should not fetch optimizer and gradient accumulation. // This path may fetch predicted scores, loss value, and so on. if (params_.pipeline_parallel_size > 1) { @@ -658,7 +659,7 @@ Status TrainingRunner::RunWithUpdate(VectorString& feed_names, } } - // Wait all workers to finish this around of pipeline parallism. + // Wait for all workers to finish this round of pipeline parallelism. // The last batch in a pipeline collects gradient and update the model. pipeline_worker_pool_.JoinAll(); @@ -802,7 +803,8 @@ Status TrainingRunner::TrainingLoop(IDataLoader& training_data_loader, IDataLoad fetch_names, fetches); RunWithoutUpdate(feed_names, fetch_names, feeds, - gradient_accumulation_step_count); + gradient_accumulation_step_count); + } auto end = std::chrono::high_resolution_clock::now(); @@ -880,7 +882,7 @@ Status TrainingRunner::TrainingLoop(IDataLoader& training_data_loader, IDataLoad const size_t peak_workingset_size = perftest::utils::GetPeakWorkingSetSize(); ORT_RETURN_IF_ERROR(Env::Default().CreateFolder(params_.perf_output_dir)); // saving json file - ORT_RETURN_IF_ERROR(SavePerfMetrics(number_of_batches, gradient_accumulation_step_count, weight_update_steps, + ORT_RETURN_IF_ERROR(SavePerfMetrics(number_of_batches, gradient_accumulation_step_count, weight_update_steps, total_time, avg_time_per_batch, throughput, stabilized_throughput, mapped_dimensions, average_cpu_usage, peak_workingset_size)); } @@ -901,17 +903,17 @@ Status TrainingRunner::TrainingLoop(IDataLoader& training_data_loader, IDataLoad Status TrainingRunner::SavePerfMetrics(const size_t
number_of_batches, const size_t gradient_accumulation_steps, const size_t weight_update_steps, const double total_time, const double avg_time_per_batch, const double throughput, const double stabilized_throughput, - const MapStringToString& mapped_dimensions, + const MapStringToString& mapped_dimensions, const short average_cpu_usage, const size_t peak_workingset_size) { // populate metrics for reporting json perf_metrics; - perf_metrics["Model"] = params_.model_type; + perf_metrics["Model"] = params_.model_type; // loop thru the mapped_dimensions and put it in json sub-structure std::string seq_len; for (auto const& it : mapped_dimensions) { if (it.first == "SeqLen") { - seq_len = it.second; + seq_len = it.second; } perf_metrics["DerivedProperties"][it.first] = it.second; } @@ -929,7 +931,7 @@ Status TrainingRunner::SavePerfMetrics(const size_t number_of_batches, const siz std::string optimizer = params_.training_optimizer_name; std::size_t pos = optimizer.find("Optimizer"); - if (pos != std::string::npos) + if (pos != std::string::npos) optimizer = optimizer.substr(0, pos); perf_metrics["Optimizer"] = optimizer; @@ -948,7 +950,7 @@ Status TrainingRunner::SavePerfMetrics(const size_t number_of_batches, const siz // // we will get date/time and commitId in post-run pipeline - // + // // populate other basic params for bookkeeping - add more as needed json bookkeeping_params; @@ -962,7 +964,7 @@ Status TrainingRunner::SavePerfMetrics(const size_t number_of_batches, const siz perf_metrics["RunConfig"] = bookkeeping_params.dump(); // serialize the params as json string - std::string json_string = perf_metrics.dump(); + std::string json_string = perf_metrics.dump(); // write to a file - the next task in CI will pick up all files with the same prefix const PathString perf_metrics_path = @@ -1072,7 +1074,7 @@ Status TrainingRunner::Evaluate(InferenceSession& session, IDataLoader& data_loa fetch_names, &fetches)); - // Assume that user-specified fetches are avaliable only 
on the last pipeline stage. + // Assume that user-specified fetches are available only on the last pipeline stage. // When there is no pipeline, all pipeline_context_.pipeline_stage_id should be 0 and // params_.pipeline_parallel_size is 1. Thus, the following condition is always true if there // is no pipeline. diff --git a/orttraining/orttraining/models/runner/training_runner.h b/orttraining/orttraining/models/runner/training_runner.h index 2ebad08a90..5c4371f1e8 100644 --- a/orttraining/orttraining/models/runner/training_runner.h +++ b/orttraining/orttraining/models/runner/training_runner.h @@ -156,8 +156,11 @@ class TrainingRunner { // pipeline_parallel_size > 1 means pipeline is enabled. // pipeline_parallel_size == 1 means pipeline is disabled. int pipeline_parallel_size = 1; + // Pipeline partition information used to do online partitioning. If the graph is + // pre-partitioned, there is no need to fill this value. + std::vector pipeline_partition_cut_list; // model_paths[i] is the name of the pipeline stage for i-th process. - // The i-th file is run by the i-th MPI rank. + // The i-th file is run by the i-th MPI rank. // If model_paths is not empty, model partition transformation may not be internally invoked. VectorString pipeline_stage_paths; // Enable gradient clipping.
@@ -169,7 +172,7 @@ class TrainingRunner { common::Status Initialize(); - common::Status Run(IDataLoader* training_data_loader, IDataLoader* test_data_loader, + common::Status Run(IDataLoader* training_data_loader, IDataLoader* test_data_loader, const MapStringToString& mapped_dimensions = {}); common::Status EndTraining(IDataLoader* data_loader); @@ -196,12 +199,12 @@ class TrainingRunner { Status RunWithUpdate(VectorString& feed_names, VectorString& fetch_names, std::vector& feeds, - std::vector& fetches); + std::vector& fetches); Status RunWithoutUpdate(VectorString& feed_names, VectorString& fetch_names, std::vector& feeds, - size_t& gradient_accumulation_step_count); - Status TrainingLoop(IDataLoader& training_data_loader, IDataLoader* test_data_loader, + size_t& gradient_accumulation_step_count); + Status TrainingLoop(IDataLoader& training_data_loader, IDataLoader* test_data_loader, const MapStringToString& mapped_dimensions); Status Evaluate(InferenceSession& session, IDataLoader& data_loader); @@ -230,7 +233,7 @@ class TrainingRunner { AllocatorPtr input_allocator_; std::unique_ptr checkpoint_registry_; - + // Pipeline fields are valid only if params_.pipeline_parallel_size > 1. // Information for running pipeline. 
pipeline::PipelineContext pipeline_context_; diff --git a/orttraining/orttraining/test/graph/gradient_graph_builder_test.cc b/orttraining/orttraining/test/graph/gradient_graph_builder_test.cc index 1cbf5800aa..b2f888557c 100644 --- a/orttraining/orttraining/test/graph/gradient_graph_builder_test.cc +++ b/orttraining/orttraining/test/graph/gradient_graph_builder_test.cc @@ -998,7 +998,7 @@ class PipelineBatchPlanner { }; void RetrieveEventOperators( - Graph& graph, + Graph& graph, Node** forward_wait_before_recv, Node** forward_wait_after_recv, Node** forward_record_before_send, @@ -1065,7 +1065,7 @@ void RetrieveEventOperators( } void RetrieveSendRecvOperators( - Graph& graph, + Graph& graph, Node** forward_recv, Node** forward_send, Node** backward_recv, @@ -1089,9 +1089,9 @@ void RetrieveSendRecvOperators( if (is_backward(node)) { // backward_send can only be assigned one valid pointer. // If it is assigned more than once, it means we have multiple - // Send in backward pass and therefore our assumption doesn't hold. + // Send in backward pass and therefore our assumption doesn't hold. // This check ensure that only we only update *backward_send when - // its value is NULL and guards our one-Recv assumption. + // its value is NULL and guards our one-Recv assumption. 
ASSERT_TRUE(!(*backward_send)); *backward_send = &node; } else { @@ -1115,6 +1115,46 @@ void RetrieveSendRecvOperators( } } +TEST(GradientGraphBuilderTest, PipelineOnlinePartition) { + auto model_uri = ORIGINAL_MODEL_PATH; + + TrainingSession::TrainingConfiguration::PipelineConfiguration pipe{}; + pipe.do_partition = true; + + // evenly cut the MLP model in 3 partitions + TrainingSession::TrainingConfiguration::CutInfo cut0 = {TrainingSession::TrainingConfiguration::CutEdge("T3")}; + TrainingSession::TrainingConfiguration::CutInfo cut1 = {TrainingSession::TrainingConfiguration::CutEdge("T6")}; + pipe.cut_list.emplace_back(cut0); + pipe.cut_list.emplace_back(cut1); + + for (int i = 0; i < 3; ++i) { +#ifdef _WIN32 + auto surfix = std::to_wstring(i); +#else + auto surfix = std::to_string(i); +#endif + PathString output_file = ORT_TSTR("pipeline_partition_") + surfix + ORT_TSTR("_back.onnx"); + + auto config = MakeBasicTrainingConfig(); + config.pipeline_config = pipe; + config.distributed_config.world_rank = i; + config.distributed_config.world_size = 3; + config.distributed_config.local_rank = i; + config.distributed_config.local_size = 3; + config.distributed_config.data_parallel_size = 1; + config.distributed_config.horizontal_parallel_size = 1; + config.distributed_config.pipeline_parallel_size = 3; + config.model_with_training_graph_path = output_file; + + PathString backprop_model_file; + ASSERT_STATUS_OK(BuildBackPropGraph(model_uri, config, backprop_model_file)); + + std::shared_ptr model; + // Ensure the partitioned model load. + ASSERT_STATUS_OK(Model::Load(backprop_model_file, model, nullptr, DefaultLoggingManager().DefaultLogger())); + } +} + // verify pipeline config can load and gradient graph can construct. 
TEST(GradientGraphBuilderTest, TrainingSession_PipelineTransform_base) { PathString filename_base = ORT_TSTR("testdata/test_training_model_"); @@ -1197,7 +1237,7 @@ TEST(GradientGraphBuilderTest, TrainingSession_PipelineTransform_base) { Node* backward_send{nullptr}; RetrieveSendRecvOperators( - graph, + graph, &forward_recv, &forward_send, &backward_recv,