Add pipeline online partition logic for pipeline (#3996)

* online partition

* fix when multiple consumer nodes are in cut info

* fix windows build

* address feedback

* adding test

* feedback

* address feedback

* add parser for cut edge

* windows build
Xueyun Zhu 2020-05-26 17:44:09 -07:00 committed by GitHub
parent 0d8abc1a99
commit 633008b5ef
11 changed files with 612 additions and 77 deletions


@ -525,6 +525,11 @@ class Graph {
return graph_inputs_including_initializers_;
}
/** Returns true if "node_arg" is a graph input or an initializer. Otherwise, returns false. */
bool IsInputsIncludingInitializers(const NodeArg* node_arg) const noexcept {
return std::find(graph_inputs_including_initializers_.begin(), graph_inputs_including_initializers_.end(), node_arg) != graph_inputs_including_initializers_.end();
}
/** Gets the Graph inputs that are initializers
These are overridable initializers. This is a difference between
graph_inputs_including_initializers_ and graph_inputs_excluding_initializers_
@ -537,6 +542,10 @@ class Graph {
@remarks Contains no nullptr values.*/
const std::vector<const NodeArg*>& GetOutputs() const noexcept { return graph_outputs_; }
bool IsOutput(const NodeArg* node_arg) const noexcept {
return std::find(graph_outputs_.begin(), graph_outputs_.end(), node_arg) != graph_outputs_.end();
}
/** Returns a vector with the indexes of the outputs of the given Node that are also Graph outputs. */
std::vector<int> GetNodeOutputsInGraphOutputs(const Node& node) const {
int output_idx = 0;
@ -557,6 +566,8 @@ class Graph {
@remarks Contains no nullptr values. */
const std::vector<const NodeArg*>& GetValueInfo() const noexcept;
void AddValueInfo(const NodeArg* new_value_info);
/** Gets the Node with the specified node index.
@returns Node instance if found. nullptr if node_index is invalid or node has been freed.
*/


@ -12,7 +12,7 @@ namespace onnxruntime {
class Node;
/**
Class to filter out null entries from either a vector of unique_ptr<Node> or a vector of [const] Node* and
provide an iterator interface that returns [const] Node& for the valid entries.
*/
template <typename TNodesContainer>
@ -29,6 +29,7 @@ class ValidNodes {
using ConstNodeIterator = NodeIterator<typename TNodesContainer::const_iterator>;
using MutableNodeIterator = NodeIterator<typename TNodesContainer::iterator>;
using ConstReverseNodeIterator = NodeIterator<typename TNodesContainer::const_reverse_iterator>;
ConstNodeIterator cbegin() const noexcept {
return {nodes_.cbegin(), nodes_.cend()};
@ -46,6 +47,14 @@ class ValidNodes {
return cend();
}
ConstReverseNodeIterator rbegin() const noexcept {
return {nodes_.crbegin(), nodes_.crend()};
}
ConstReverseNodeIterator rend() const noexcept {
return {nodes_.crend(), nodes_.crend()};
}
MutableNodeIterator begin() noexcept {
return {nodes_.begin(), nodes_.end()};
}
@ -56,10 +65,10 @@ class ValidNodes {
bool empty() const noexcept { return nodes_.empty(); }
/**
@class NodeIterator
Iterator to provide const and non-const access to valid Node instances in a Graph.
@remarks Skips invalid nodes.
*/
template <typename TIterator>
class NodeIterator {
@ -130,7 +139,7 @@ class ValidNodes {
};
/**
Class that provides iteration over all valid nodes in the Graph.
*/
class GraphNodes : public ValidNodes<std::vector<std::unique_ptr<Node>>> {
public:


@ -2337,6 +2337,13 @@ const std::vector<const NodeArg*>& Graph::GetValueInfo() const noexcept {
return value_info_;
}
void Graph::AddValueInfo(const NodeArg* new_value_info) {
for (const auto* info : value_info_) {
ORT_ENFORCE(info->Name() != new_value_info->Name(), "Error: trying to add an existing value info.");
}
value_info_.push_back(new_value_info);
}
std::vector<NodeArg*> Graph::CreateNodeArgs(const google::protobuf::RepeatedPtrField<std::string>& names,
const ArgNameToTypeMap& name_to_type_map) {
const auto name_to_type_map_end = name_to_type_map.end();


@ -2,6 +2,7 @@
// Licensed under the MIT License.
#include "orttraining/core/graph/pipeline_transformer.h"
#include <queue>
using namespace onnxruntime::common;
@ -22,21 +23,22 @@ bool IsBackward(Node& node) {
return (node.Description() == "Backward pass");
}
NodeArg& CreateInt64NodeArg(Graph& graph, const std::string& name) {
NodeArg& CreateTypedNodeArg(Graph& graph, onnx::TensorProto_DataType type, const std::string& name) {
ONNX_NAMESPACE::TypeProto type_proto;
type_proto.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
type_proto.mutable_tensor_type()->set_elem_type(type);
auto actual_name = graph.GenerateNodeArgName(name);
auto& node_arg = graph.GetOrCreateNodeArg(actual_name, &type_proto);
return node_arg;
}
void AddInputEvent(Graph& graph,
const std::string& event_name,
std::vector<NodeArg*>& input_args,
std::vector<std::string>& new_input_names) {
auto& event_id = CreateInt64NodeArg(graph, event_name);
new_input_names.push_back(event_id.Name());
input_args.push_back(&event_id);
void AddNewNodeArg(Graph& graph,
const std::string& op_name,
onnx::TensorProto_DataType type,
std::vector<NodeArg*>& new_node_args,
std::vector<std::string>& new_names) {
auto& new_node_arg = CreateTypedNodeArg(graph, type, op_name);
new_names.push_back(new_node_arg.Name());
new_node_args.push_back(&new_node_arg);
}
// gradient graph can contain some dangling leaf nodes. Add them all to WaitEvent
@ -112,7 +114,7 @@ Node& CreateBottleneckNode(Graph& graph,
nullptr /* assume all bottleneck node have no attributes */,
kMSDomain);
}
Node* AddBackwardRecord(Graph& graph,
Node* backward_send,
std::vector<std::string>& new_input_names,
@ -120,7 +122,8 @@ Node* AddBackwardRecord(Graph& graph,
std::string &event_id_tensor_name,
std::string &output_tensor_name) {
std::vector<NodeArg*> input_args;
AddInputEvent(graph, "backward_recorded_event_id", input_args, new_input_names);
AddNewNodeArg(graph, "backward_recorded_event_id", ONNX_NAMESPACE::TensorProto_DataType_INT64,
input_args, new_input_names);
std::vector<NodeArg*> output_args{};
if (backward_send) {
@ -175,10 +178,11 @@ Node* AddForwardWait(Graph& graph,
std::vector<NodeArg*> input_args;
std::vector<NodeArg*> output_args;
AddInputEvent(graph, "forward_waited_event_id", input_args, new_input_names);
AddNewNodeArg(graph, "forward_waited_event_id", ONNX_NAMESPACE::TensorProto_DataType_INT64,
input_args, new_input_names);
const std::vector<const NodeArg*>& graph_inputs = graph.GetInputsIncludingInitializers();
if (graph_inputs.size() == 0){
if (graph_inputs.size() == 0) {
ORT_THROW("Graph ", graph.Name(), " doesn't have any inputs.");
}
@ -237,7 +241,8 @@ Status AddOrSkipForwardRecordBackwardWait(Graph& graph,
{
std::vector<NodeArg*> input_args;
std::vector<NodeArg*> output_args;
AddInputEvent(graph, "forward_recorded_event_id", input_args, new_input_names);
AddNewNodeArg(graph, "forward_recorded_event_id", ONNX_NAMESPACE::TensorProto_DataType_INT64,
input_args, new_input_names);
// Add send forward op's output as record op's input and output
for (auto& output : forward_send->MutableOutputDefs()) {
@ -258,7 +263,8 @@ Status AddOrSkipForwardRecordBackwardWait(Graph& graph,
{
std::vector<NodeArg*> input_args;
std::vector<NodeArg*> output_args;
AddInputEvent(graph, "backward_waited_event_id", input_args, new_input_names);
AddNewNodeArg(graph, "backward_waited_event_id", ONNX_NAMESPACE::TensorProto_DataType_INT64,
input_args, new_input_names);
input_args.insert(std::end(input_args),
std::begin(record_node->MutableOutputDefs()),
@ -287,7 +293,6 @@ void ReplaceNodeArgs(std::vector<Node*>& nodes,
ORT_ENFORCE(node_args.size() == new_node_args.size());
for (size_t i = 0; i < node_args.size(); ++i) {
// Iteration for node_args[i] and new_node_args[i].
ORT_ENFORCE(node_args[i]->Name() != new_node_args[i]->Name());
ORT_ENFORCE(node_args[i]->Type() == new_node_args[i]->Type());
@ -331,10 +336,10 @@ std::string AddEventBeforeNode(
std::vector<Node*> nodes = {node};
// Replace node_args[i] with new_node_args[i] in nodes.
ReplaceNodeArgs(nodes, node_args, new_node_args);
// Create node_arg for event ID.
auto event_node_arg = &CreateInt64NodeArg(graph, event_id_name);
auto event_node_arg = &CreateTypedNodeArg(graph, ONNX_NAMESPACE::TensorProto_DataType_INT64, event_id_name);
// Create node which produces new_node_args from event ID and node_args.
CreateBottleneckNode(graph,
@ -372,11 +377,11 @@ std::string AddEventAfterNode(
std::vector<Node*> consumer_nodes = graph.GetMutableConsumerNodes(
node_args[i]->Name());
// Replace node_args[i] with new_node_args[i] in nodes.
ReplaceNodeArgs(consumer_nodes, {node_args[i]}, {new_node_args[i]});
}
// Create node_arg for event ID.
auto event_node_arg = &CreateInt64NodeArg(graph, event_id_name);
auto event_node_arg = &CreateTypedNodeArg(graph, ONNX_NAMESPACE::TensorProto_DataType_INT64, event_id_name);
// Create node which produces new_node_args from event ID and node_args.
CreateBottleneckNode(graph,
@ -395,10 +400,10 @@ Status AddForwardWaitAfterRecv(
Node* comm_node,
std::vector<std::string>& new_input_names,
std::string& event_name) {
event_name = AddEventAfterNode(
graph, comm_node,
"WaitEvent", "forward_wait_after_recv",
"forward_wait_after_recv_event_id");
"forward_wait_after_recv_event_id");
if (event_name.empty()) {
return Status::OK();
} else {
@ -458,6 +463,37 @@ Status AddBackwardRecordBeforeSend(
}
}
Status SetInputsOutputsAndResolve(Graph& graph,
const std::vector<std::string>& new_input_names,
const std::vector<std::string>& new_output_names) {
auto fill_node_args = [&](const Graph& graph,
const std::vector<const NodeArg*>& existed_node_args,
const std::vector<std::string>& new_node_arg_names,
std::vector<const NodeArg*>& merged_node_args) {
merged_node_args.insert(merged_node_args.end(), existed_node_args.begin(), existed_node_args.end());
for (auto& name : new_node_arg_names) {
merged_node_args.push_back(graph.GetNodeArg(name));
}
};
const std::vector<const NodeArg*>& graph_inputs = graph.GetInputsIncludingInitializers();
std::vector<const NodeArg*> inputs_args_sets;
inputs_args_sets.reserve(graph_inputs.size() + new_input_names.size());
fill_node_args(graph, graph_inputs, new_input_names, inputs_args_sets);
const std::vector<const NodeArg*>& graph_outputs = graph.GetOutputs();
std::vector<const NodeArg*> outputs_args_sets;
outputs_args_sets.reserve(graph_outputs.size() + new_output_names.size());
fill_node_args(graph, graph_outputs, new_output_names, outputs_args_sets);
graph.SetInputs(inputs_args_sets);
graph.SetOutputs(outputs_args_sets);
graph.SetGraphResolveNeeded();
graph.SetGraphProtoSyncNeeded();
return graph.Resolve();
}
// This function inserts WaitEvent's and RecordEvent's to the input graph for
// controlling synchronization between (batch, pipeline stage)-pairs.
//
@ -570,7 +606,7 @@ Status TransformGraphForPipeline(
backward_waited_event_name,
forward_record_output_name,
backward_wait_output_name));
// Different stages have different patterns of Send & Recv.
// For different patterns, we add different WaitEvent and Record.
//
@ -691,33 +727,357 @@ Status TransformGraphForPipeline(
));
}
auto fill_node_args = [&](const Graph& graph,
const std::vector<const NodeArg*>& existed_node_args,
std::vector<std::string>& new_node_arg_names,
std::vector<const NodeArg*>& merged_node_args) {
merged_node_args.insert(merged_node_args.end(), existed_node_args.begin(), existed_node_args.end());
for (auto& name : new_node_arg_names) {
merged_node_args.push_back(graph.GetNodeArg(name));
ORT_RETURN_IF_ERROR(SetInputsOutputsAndResolve(graph, new_input_names, new_output_names));
return Status::OK();
}
// This function is used when you want to create a scalar constant in a graph.
// It may create a NodeArg so that other Nodes can reference its value.
// It also creates an initializer to store its value.
template <typename T>
void AddNewScalarNodeArgAndInitializer(Graph& graph,
const std::string& op_name,
onnx::TensorProto_DataType type,
T data,
std::vector<NodeArg*>& new_node_args,
std::vector<std::string>& new_names) {
AddNewNodeArg(graph, op_name, type, new_node_args, new_names);
ONNX_NAMESPACE::TensorProto proto_data;
proto_data.set_name(new_names.back());
proto_data.set_data_type(type);
switch (type) {
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
proto_data.add_int32_data(static_cast<int32_t>(data));
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT64:
proto_data.add_int64_data(static_cast<int64_t>(data));
break;
default:
ORT_THROW("pipeline partition unsupported 'type' value: ", type);
}
graph.AddInitializedTensor(proto_data);
}
// split the graph into disconnected subgraphs based on the provided CutInfo
common::Status SplitGraph(Graph& graph,
std::vector<TrainingSession::TrainingConfiguration::CutInfo> split_edge_groups,
std::vector<Node*>& send_nodes,
std::vector<Node*>& recv_nodes) {
std::vector<std::string> new_input_names;
std::vector<std::string> new_output_names;
// updated_node_args keeps track of the mapping between the original node_arg and its corresponding updated
// node_arg after the send and recv nodes are added. As multiple partitions can happen, and a single node_arg
// can belong to different partitions, updated_node_args always keeps track of the latest updated node_arg.
// Below is one example of how this works using updated_node_args:
// there are three edges in the graph, specified as nodeA->nodeB, nodeA->nodeC, and nodeA->nodeD.
// Those edges all share the same node_arg,
// but nodeA and nodeB belong to partition0, nodeC belongs to partition1, and nodeD belongs to partition2.
// This means we need to cut edge nodeA->nodeC for the first partition and nodeA->nodeD for the second partition.
//
// During the first cut, we identify the edge nodeA->nodeC. For this edge, based on the original node_arg,
// we create a new node_arg, called updated_node_arg. The inserted send node will take the original node_arg
// as input and the inserted recv node will take the updated_node_arg as the output.
// And we update updated_node_args with updated_node_args[original_node_arg] = updated_node_arg.
//
// Now during the second cut, we need to cut the edge nodeA->nodeD. Note that as the cuts are performed sequentially,
// the second cut is performed on the graph modified after the first cut. This means the input node_arg for
// nodeD shouldn't come from nodeA anymore, as nodeA now resides in partition0, which is a disconnected partition.
// Instead, the input node_arg of nodeD should come from the updated version: updated_node_arg from partition1.
// By using the updated_node_args map, we can retrieve updated_node_arg from original_node_arg, and use that as the
// newly inserted send's input. Also, to keep this going for any following cut, we create an updated_node_arg_v2,
// and update updated_node_args with updated_node_args[original_node_arg] = updated_node_arg_v2.
std::map<NodeArg*, NodeArg*> updated_node_args;
for (size_t index = 0; index < split_edge_groups.size(); ++index) {
// each entry in split_edge_groups represents a partition cut. Each cut can contain the split of
// several edges.
auto& edgeIds = split_edge_groups[index];
// for each cut, record the inserted input/output args.
std::vector<NodeArg*> send_input_args;
std::vector<NodeArg*> send_output_args;
std::vector<NodeArg*> recv_input_args;
std::vector<NodeArg*> recv_output_args;
auto cut_index_str = std::to_string(index);
// add input node_arg and initializer for send/recv
AddNewScalarNodeArgAndInitializer<bool>(graph,
"send_input_signal" + cut_index_str,
ONNX_NAMESPACE::TensorProto_DataType_BOOL,
true, /* initializer data */
send_input_args,
new_input_names);
AddNewScalarNodeArgAndInitializer<bool>(graph,
"recv_input_signal" + cut_index_str,
ONNX_NAMESPACE::TensorProto_DataType_BOOL,
true, /* initializer data */
recv_input_args,
new_input_names);
AddNewScalarNodeArgAndInitializer<size_t>(graph,
"send_dst_rank" + cut_index_str,
ONNX_NAMESPACE::TensorProto_DataType_INT64,
index + 1, /* initializer data */
send_input_args,
new_input_names);
AddNewScalarNodeArgAndInitializer<size_t>(graph,
"recv_src_rank" + cut_index_str,
ONNX_NAMESPACE::TensorProto_DataType_INT64,
index, /* initializer data */
recv_input_args,
new_input_names);
// add output node_arg for send/recv
AddNewNodeArg(graph, "send_output_signal" + cut_index_str, ONNX_NAMESPACE::TensorProto_DataType_BOOL,
send_output_args, new_output_names);
AddNewNodeArg(graph, "receive_output_signal" + cut_index_str, ONNX_NAMESPACE::TensorProto_DataType_BOOL,
recv_output_args, new_output_names);
// add attribute data for send/recv
ONNX_NAMESPACE::AttributeProto tag;
tag.set_name("tag");
tag.set_type(ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_INT);
// currently all tags are hard-coded to 0. This may need to change when multiple GPU streams are used.
tag.set_i(static_cast<int64_t>(0));
ONNX_NAMESPACE::AttributeProto element_types;
element_types.set_name("element_types");
element_types.set_type(ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_INTS);
// for each edge in this group, perform edge cut
for (auto& id : edgeIds) {
// find node whose output contains id.node_arg_name
auto producer_node = graph.GetMutableProducerNode(id.node_arg_name);
if (!producer_node) {
ORT_THROW("Cannot find producer node of node_arg with name: ", id.node_arg_name, ". Wrong cutting infomation.");
}
// once we find the producer node for id.node_arg_name, find which output index leads
// to id.node_arg_name
int upstream_nodes_output_index{-1};
producer_node->ForEachWithIndex(
producer_node->OutputDefs(),
[&](const NodeArg& def, size_t index) {
if (def.Name() == id.node_arg_name) {
upstream_nodes_output_index = static_cast<int>(index);
}
return Status::OK();
});
if (upstream_nodes_output_index < 0) {
ORT_THROW("Node with name: ", producer_node->Name(),
" doesn't have an output node_arg with name ", id.node_arg_name);
}
size_t idx = static_cast<size_t>(upstream_nodes_output_index);
// original node_arg pointer from the original graph. This serves as the key in the
// updated_node_args map and any reference to the original node_arg name.
auto* original_node_arg = producer_node->MutableOutputDefs()[idx];
// updated node_arg pointer from the previous partition. This is the new node_arg the
// currently inserted send node will take as input node_arg.
auto updated_node_arg = producer_node->MutableOutputDefs()[idx];
auto existing_updated_node_arg = updated_node_args.find(original_node_arg);
if (existing_updated_node_arg != updated_node_args.end()) {
updated_node_arg = existing_updated_node_arg->second;
}
send_input_args.push_back(updated_node_arg);
auto dtype = original_node_arg->TypeAsProto()->tensor_type().elem_type();
switch (dtype) {
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
element_types.add_ints(static_cast<int64_t>(1));
break;
default:
// Assume all tensors are of type float.
// TODO: update if graph supports other data type.
ORT_THROW("pipeline partition unsupported 'type' value: ", dtype);
}
auto& new_receive_output = CreateNodeArg(graph, updated_node_arg);
const auto old_shape = *(updated_node_arg->Shape());
new_receive_output.SetShape(old_shape);
recv_output_args.push_back(&new_receive_output);
// add value info for this newly added receive_output, for shape propagation
// when training this partition.
graph.AddValueInfo(&new_receive_output);
// update updated_node_args with the newly created node_arg
updated_node_args[original_node_arg] = &new_receive_output;
// deal with shape inference for newly added edge
auto& output_edge_name = original_node_arg->Name();
// deal with updating the consumer's input node_args
std::vector<Node*> consumer_nodes;
if (id.consumer_nodes.has_value()) {
for (auto& consumer_node_id : id.consumer_nodes.value()) {
consumer_nodes.push_back(graph.GetMutableProducerNode(consumer_node_id));
}
} else {
consumer_nodes = graph.GetMutableConsumerNodes(output_edge_name);
}
for (auto consumer_node : consumer_nodes) {
for (auto& input : consumer_node->MutableInputDefs()) {
if (input->Name() == output_edge_name) {
input = &new_receive_output;
break;
}
}
}
}
};
const int num_attributes = 2; // two attributes: tag and element_types
NodeAttributes attributes;
attributes.reserve(num_attributes);
attributes[tag.name()] = tag;
attributes[element_types.name()] = element_types;
const std::vector<const NodeArg*>& graph_inputs = graph.GetInputsIncludingInitializers();
std::vector<const NodeArg*> inputs_args_sets;
inputs_args_sets.reserve(graph_inputs.size() + new_input_names.size());
fill_node_args(graph, graph_inputs, new_input_names, inputs_args_sets);
auto& send_node = graph.AddNode(graph.GenerateNodeName("Send"),
"Send",
"",
send_input_args,
send_output_args, /* output */
&attributes, /* attribute */
kMSDomain);
const std::vector<const NodeArg*>& graph_outputs = graph.GetOutputs();
std::vector<const NodeArg*> outputs_args_sets;
outputs_args_sets.reserve(graph_outputs.size() + new_output_names.size());
fill_node_args(graph, graph_outputs, new_output_names, outputs_args_sets);
send_nodes.push_back(&send_node);
graph.SetInputs(inputs_args_sets);
graph.SetOutputs(outputs_args_sets);
auto& recv_node = graph.AddNode(graph.GenerateNodeName("Recv"),
"Recv",
"",
recv_input_args,
recv_output_args, /* output */
&attributes, /* attribute */
kMSDomain);
recv_nodes.push_back(&recv_node);
}
ORT_RETURN_IF_ERROR(SetInputsOutputsAndResolve(graph, new_input_names, new_output_names));
return Status::OK();
}
Status FindAllConnectedNodes(Graph& graph,
const Node* node,
std::vector<const Node*>& connected_nodes,
std::set<const NodeArg*>& connected_inputs,
std::set<const NodeArg*>& connected_outputs) {
ORT_THROW_IF_ERROR(node->ForEachWithIndex(
node->InputDefs(),
[&](const NodeArg& node_arg, size_t /*index*/) {
if (graph.IsInputsIncludingInitializers(&node_arg)) {
connected_inputs.insert(&node_arg);
} else {
const Node* producer_node = graph.GetProducerNode(node_arg.Name());
if (producer_node == nullptr) {
// got nullptr as producer node. This could be because the input is a constant op which will be optimized
// away. Print out this information and continue.
LOGS_DEFAULT(WARNING) << "Cannot find producer node for node_arg: " << node_arg.Name() << ". Skipping this node.";
} else {
connected_nodes.push_back(producer_node);
}
}
return Status::OK();
}));
ORT_THROW_IF_ERROR(node->ForEachWithIndex(
node->OutputDefs(),
[&](const NodeArg& node_arg, size_t /*index*/) {
if (!graph.IsOutput(&node_arg)) {
std::vector<const Node*> consumer_nodes = graph.GetConsumerNodes(node_arg.Name());
connected_nodes.insert(std::end(connected_nodes), consumer_nodes.begin(), consumer_nodes.end());
} else {
connected_outputs.insert(&node_arg);
}
return Status::OK();
}));
return Status::OK();
}
// traverse the graph from start_node to get the set of nodes contained in this disconnected subgraph
common::Status GenerateSubgraph(Graph& graph, const Node* start_node) {
std::queue<const Node*> node_queue;
node_queue.push(start_node);
std::set<const Node*> visited_nodes;
std::set<const NodeArg*> visited_inputs;
std::set<const NodeArg*> visited_outputs;
// BFS graph traversal
while (!node_queue.empty()) {
auto node = node_queue.front();
node_queue.pop();
if (visited_nodes.count(node) == 0) {
visited_nodes.insert(node);
std::vector<const Node*> connected_nodes;
ORT_THROW_IF_ERROR(FindAllConnectedNodes(graph, node, connected_nodes, visited_inputs, visited_outputs));
for (auto n : connected_nodes) {
ORT_ENFORCE(n!=nullptr, "Found nullptr in searching for connected nodes");
node_queue.push(n);
}
}
}
std::set<NodeIndex> visited_node_index;
for (auto n : visited_nodes) {
visited_node_index.insert(n->Index());
}
GraphViewer graph_viewer(graph);
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
// reverse iterate the nodes in topological order, and delete those not visited
for (auto it = node_topology_list.rbegin(); it != node_topology_list.rend(); it++) {
if (visited_node_index.count(*it) == 0) {
graph.RemoveNode(*it);
}
}
// update the graph with only visited inputs and outputs
graph.SetInputs({visited_inputs.begin(), visited_inputs.end()});
graph.SetOutputs({visited_outputs.begin(), visited_outputs.end()});
graph.SetGraphResolveNeeded();
graph.SetGraphProtoSyncNeeded();
return graph.Resolve();
}
Status ApplyPipelinePartitionToMainGraph(
Graph& graph,
const std::vector<TrainingSession::TrainingConfiguration::CutInfo>& cut_info,
size_t pipeline_stage_id,
size_t num_pipeline_stage) {
size_t split_count = cut_info.size();
if (num_pipeline_stage != split_count + 1) {
ORT_THROW("Wrong pipeline partition cutting info. Total pipeline stage number is ",
num_pipeline_stage,
", cut info length is: ",
split_count);
}
std::vector<Node *> send_nodes, recv_nodes;
send_nodes.reserve(split_count);
recv_nodes.reserve(split_count);
ORT_RETURN_IF_ERROR(SplitGraph(graph, cut_info, send_nodes, recv_nodes));
if (send_nodes.size() != split_count || recv_nodes.size() != split_count) {
ORT_THROW("Split error: not all cut has Send and Recv inserted. Send node count: ",
send_nodes.size(), ", Recv node count: ", recv_nodes.size(), ", split count: ", split_count);
}
if (pipeline_stage_id < split_count) {
ORT_RETURN_IF_ERROR(GenerateSubgraph(graph, send_nodes[pipeline_stage_id]));
} else {
ORT_RETURN_IF_ERROR(GenerateSubgraph(graph, recv_nodes.back()));
}
return Status::OK();
}
} // namespace training
} // namespace onnxruntime


@ -4,12 +4,14 @@
#pragma once
#include "core/graph/graph.h"
#include "orttraining/core/session/training_session.h"
namespace onnxruntime {
namespace training {
void GetPipelineSendOutput(const Graph& graph, std::string& loss_name);
common::Status TransformGraphForPipeline(
Status TransformGraphForPipeline(
Graph& graph,
std::string& forward_waited_event_name,
std::string& forward_recorded_event_name,
@ -24,5 +26,10 @@ common::Status TransformGraphForPipeline(
std::string& backward_waited_event_after_recv_name,
std::string& backward_recorded_event_before_send_name);
Status ApplyPipelinePartitionToMainGraph(
Graph& graph,
const std::vector<TrainingSession::TrainingConfiguration::CutInfo>& cut_info,
size_t pipeline_stage_id,
size_t num_pipeline_stage);
} // namespace training
} // namespace onnxruntime


@ -141,6 +141,15 @@ Status TrainingSession::ConfigureForTraining(
excluded_initializers.erase(weight_name_to_not_train);
}
if (config.pipeline_config.has_value() && config.pipeline_config.value().do_partition) {
// Apply online pipeline partition to the graph object. This needs to be done before any graph
// transformation which may alter node_args and invalidate the cut_list info from the original graph.
ORT_RETURN_IF_ERROR(ApplyPipelinePartitionToMainGraph(model_->MainGraph(),
config.pipeline_config.value().cut_list,
config.distributed_config.world_rank,
config.distributed_config.world_size));
}
ORT_RETURN_IF_ERROR(ApplyTransformationsToMainGraph(excluded_initializers));
is_mixed_precision_enabled_ = config.mixed_precision_config.has_value();
@ -233,7 +242,7 @@ Status TrainingSession::ConfigureForTraining(
}
pipeline_result.fetch_names.push_back(name);
}
pipeline_result.pipeline_stage_id = config.distributed_config.world_rank /
(config.distributed_config.data_parallel_size * config.distributed_config.horizontal_parallel_size);
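// Worked example (an illustration of the integer division above; the rank-to-stage layout is an assumption):
// with data_parallel_size = 2, horizontal_parallel_size = 1 and world_rank = 5, this computes
// pipeline_stage_id = 5 / (2 * 1) = 2, i.e. ranks 4 and 5 map to pipeline stage 2.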
config_result.pipeline_config_result = pipeline_result;
}
@ -307,13 +316,19 @@ Status TrainingSession::ConfigureForTraining(
tensorboard_config.histogram_node_names, tensorboard_config.norm_node_names,
tensorboard_config.dump_convergence_metrics));
}
// add GIST encoding
if (config.gist_config.has_value()) {
ORT_RETURN_IF_ERROR(AddGistEncoding());
}
if (IsRootNode(config) && config.model_with_training_graph_path.has_value()) {
// If the current node is rank0, or if the current session is running pipeline (in which case different ranks
// store different model partitions), and if model_with_training_graph_path is specified, save the model.
// Note: in the pipeline case, different ranks may reside on the same node. This could lead to a potential write
// conflict. It is the user's responsibility to make sure different ranks are passed different
// model_with_training_graph_path values.
if ((IsRootNode(config) || config.pipeline_config.has_value())
&& config.model_with_training_graph_path.has_value()) {
ORT_IGNORE_RETURN_VALUE(Save(
config.model_with_training_graph_path.value(), SaveOption::NO_RELOAD));
}


@ -135,6 +135,28 @@ class TrainingSession : public InferenceSession {
// If not provided, no optimizer is added.
optional<OptimizerConfiguration> optimizer_config{};
// struct to describe a specific edge. An edge is not the same as a node_arg. An edge represents a connection between two operators.
// For example, if an operator A's output tensor T is connected to another operator B's input, this constructs
// an edge from A to B. If A's output tensor T has multiple consumers, i.e. it's fed into multiple operators' inputs,
// there would be multiple edges, each from A to a consumer operator.
// CutEdge information is used by the pipeline online partition tool to identify which edges to cut to make the
// corresponding partition.
struct CutEdge {
std::string node_arg_name;
optional<std::vector<std::string>> consumer_nodes;
// If the edge is unique, i.e. it only has one consumer node, or if all the edges
// with the same node_arg_name need to be cut, specifying the node_arg_name
// suffices.
CutEdge(std::string edge) : node_arg_name(edge){};
// If edges with the same node_arg_name belong to different cuts, i.e. some of its
// consumer nodes belong to one partition and some belong to another, specify
// the consumer node names on which you want to perform the cut.
CutEdge(std::string edge, std::vector<std::string> nodes) : node_arg_name(edge), consumer_nodes(nodes){};
};
// CutInfo is a group of CutEdges that describes a specific cut composed of splitting those edges.
typedef std::vector<CutEdge> CutInfo;
struct PipelineConfiguration {
// If model partition happens outside ORT, this flag should be false.
// Otherwise, use true to trigger ORT's pipeline partition.
@ -142,7 +164,9 @@ class TrainingSession : public InferenceSession {
// Tensors to fetch as specified by the user.
// Each pipeline stage should pick up some strings from this field.
std::vector<std::string> fetch_names;
// [TODO] Add cut information.
// cut_list contains the list of CutInfo to make the graph partitions.
// cut_list[i] contains the CutInfo to make the partition between stage i and stage i+1
std::vector<CutInfo> cut_list;
};
// If pipeline is enabled, this field's has_value() returns true.
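A minimal sketch of how a caller might populate cut_list using the CutEdge constructors above; the node_arg names follow the MLP test added in this change, while the consumer-node variant for "T6" and its consumer names are purely illustrative:

#include "orttraining/core/session/training_session.h"

using Config = onnxruntime::training::TrainingSession::TrainingConfiguration;

// Cut the graph into three stages with two cuts (assumes node_args named "T3" and "T6").
Config::CutInfo cut0 = {Config::CutEdge("T3")};
// Hypothetical variant: cut only the edges from "T6" to the named consumer nodes.
Config::CutInfo cut1 = {Config::CutEdge("T6", {"MatMul2", "Gemm3"})};

Config::PipelineConfiguration pipe{};
pipe.do_partition = true;
pipe.cut_list = {cut0, cut1};  // cut_list[i] separates stage i from stage i+1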
@ -329,7 +353,7 @@ class TrainingSession : public InferenceSession {
// Insert operators for running pipeline and return event tensor names.
// For an intermediate pipeline stage, its original computation is
//
// Recv --> Forward --> Send -->
// Recv --> Backward --> Send
//
// After this function, the resulted computation is
@ -340,7 +364,7 @@ class TrainingSession : public InferenceSession {
// As you can see, some event operators are inserted. For each event operator, its dependent
// event tensor name is written to an input references, for example, "forward_waited_event_name".
//
// This function assumes that
// 1. Only one Recv and only one Send present in forward pass.
// 2. Only one Recv and only one Send present in backward pass.
// 3. Backward operators' descriptions are all "Backward pass". This assumption is used to


@ -155,6 +155,11 @@ Status ParseArguments(int argc, char* argv[], BertParameters& params, OrtParamet
("horizontal_parallel_size", "Horizontal model parallel group size.", cxxopts::value<int>()->default_value("1"))
("pipeline_parallel_size", "Number of pipeline stages.", cxxopts::value<int>()->default_value("1"))
("pipeline_stage_paths", "Specify the forward ONNX files for pipeline evaluation.", cxxopts::value<std::vector<std::string>>()->default_value(""))
("cut_group_info", "Specify the cutting info for graph partition (pipeline only). An example of a cut_group_info of "
"size two is: 1393:407-1463/1585/1707,2369:407-2439/2561/2683. Here, the cut info is split by ',', with the first "
"cut_info equal to 1393:407-1463/1585/1707, and second cut_info equal to 2369:407-2439/2561/2683. Each CutEdge is "
"seperated by ':'. If consumer nodes need to be specified, specify them after producer node with a '-' delimiter and "
"separate each consumer node with a '/'. ", cxxopts::value<std::vector<std::string>>()->default_value(""))
("enable_grad_norm_clip", "Specify whether to enable gradient clipping for optimizers.",
cxxopts::value<bool>()->default_value("true"));
options
@ -379,6 +384,58 @@ Status ParseArguments(int argc, char* argv[], BertParameters& params, OrtParamet
// Backward pass and optimizer nodes are implicitly generated by ORT.
params.pipeline_stage_paths = flags["pipeline_stage_paths"].as<std::vector<std::string>>();
// If the user doesn't provide partitioned model files, a cut list should be provided for ORT to do the
// partition online. If the pipeline contains n stages, the cut list should be of length (n-1), in order
// to cut the graph into n partitions.
if (params.pipeline_parallel_size > 1 && params.pipeline_stage_paths.empty()) {
auto cut_info_groups = flags["cut_group_info"].as<std::vector<std::string>>();
ORT_RETURN_IF_NOT(static_cast<int>(cut_info_groups.size() + 1) == params.pipeline_parallel_size,
"cut_info length plus one must match pipeline parallel size");
auto process_with_delimiter = [](std::string& input_str, const std::string& delimiter) {
std::vector<std::string> result;
size_t pos = 0;
std::string token;
while ((pos = input_str.find(delimiter)) != std::string::npos) {
token = input_str.substr(0, pos);
result.emplace_back(token);
input_str.erase(0, pos + delimiter.length());
}
// push the last split of substring into result.
result.emplace_back(input_str);
return result;
};
auto process_cut_info = [&](std::string& cut_info_string) {
TrainingSession::TrainingConfiguration::CutInfo cut_info;
const std::string edge_delimiter = ":";
const std::string consumer_delimiter = "/";
const std::string producer_consumer_delimiter = "-";
auto cut_edges = process_with_delimiter(cut_info_string, edge_delimiter);
for (auto& cut_edge : cut_edges) {
auto process_edge = process_with_delimiter(cut_edge, producer_consumer_delimiter);
if (process_edge.size() == 1) {
TrainingSession::TrainingConfiguration::CutEdge edge{process_edge[0]};
cut_info.emplace_back(edge);
} else {
ORT_ENFORCE(process_edge.size() == 2);
auto consumer_list = process_with_delimiter(process_edge[1], consumer_delimiter);
TrainingSession::TrainingConfiguration::CutEdge edge{process_edge[0], consumer_list};
cut_info.emplace_back(edge);
}
}
return cut_info;
};
for (auto& cut_info : cut_info_groups) {
TrainingSession::TrainingConfiguration::CutInfo cut = process_cut_info(cut_info);
params.pipeline_partition_cut_list.emplace_back(cut);
}
}
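// Worked example (illustration only, not part of the original change): per the help text above,
// the cut info string is split on ','. For the sample
//   1393:407-1463/1585/1707,2369:407-2439/2561/2683
// the first group, "1393:407-1463/1585/1707", parses to
//   CutInfo{ CutEdge("1393"), CutEdge("407", {"1463", "1585", "1707"}) }
// i.e. the edge named "1393" is cut for all of its consumers, while the edge named "407" is cut
// only towards consumer nodes "1463", "1585" and "1707". The second group parses analogously.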
int64_t seed = flags["seed"].as<int64_t>();
if (params.horizontal_parallel_size > 1 && seed <= 0) {
seed = 8211; // Megatron needs a random seed.


@ -158,6 +158,7 @@ Status TrainingRunner::Initialize() {
// the session already loads a pipeline stage.
pipe.do_partition = params_.pipeline_stage_paths.empty() ? true : false;
pipe.fetch_names = params_.fetch_names;
pipe.cut_list = params_.pipeline_partition_cut_list;
// Do not assign value to config.pipeline_config if pipeline is not used.
config.pipeline_config = pipe;
}
@ -321,7 +322,7 @@ Status TrainingRunner::Initialize() {
return Status::OK();
}
Status TrainingRunner::Run(IDataLoader* training_data_loader, IDataLoader* test_data_loader,
const MapStringToString& mapped_dimensions) {
if (params_.mpi_context.world_rank == 0 && !params_.model_actual_running_graph_path.empty()) {
session_.Save(params_.model_actual_running_graph_path, TrainingSession::SaveOption::NO_RELOAD);
@ -528,7 +529,7 @@ Status TrainingRunner::PrepareFetchNamesAndFetches(const SessionMode mode,
const auto& allowed_fetch_names = pipeline_context_.fetch_names;
if (mode == ModelUpdateStep) {
// Set up tensor to be fetched when doing model update.
if (params_.pipeline_parallel_size > 1) {
// If pipeline is used, we need to filter out fetches which are not in this pipeline stage.
@ -557,7 +558,7 @@ Status TrainingRunner::PrepareFetchNamesAndFetches(const SessionMode mode,
}
}
} else if (mode == GradientAccumulateStep) {
// Set up tensor to be fetched when doing gradient accumulation.
if (params_.gradient_accumulation_steps > 1) {
auto it = opt_graph_outputs_.find(OptimizerOutputKey::GradientAccumulation);
@ -583,7 +584,7 @@ Status TrainingRunner::PrepareFetchNamesAndFetches(const SessionMode mode,
}
} else if (mode == EvaluateStep) {
// Set up tensor to be fetched when doing model evaluation.
// Ideally, this path should not fetch optimizer and gradient accumulation.
// This path may fetch predicted scores, loss value, and so on.
if (params_.pipeline_parallel_size > 1) {
@ -658,7 +659,7 @@ Status TrainingRunner::RunWithUpdate(VectorString& feed_names,
}
}
// Wait for all workers to finish this round of pipeline parallelism.
// The last batch in a pipeline collects gradients and updates the model.
pipeline_worker_pool_.JoinAll();
@ -802,7 +803,8 @@ Status TrainingRunner::TrainingLoop(IDataLoader& training_data_loader, IDataLoad
fetch_names,
fetches);
RunWithoutUpdate(feed_names, fetch_names, feeds,
gradient_accumulation_step_count);
}
auto end = std::chrono::high_resolution_clock::now();
@ -880,7 +882,7 @@ Status TrainingRunner::TrainingLoop(IDataLoader& training_data_loader, IDataLoad
const size_t peak_workingset_size = perftest::utils::GetPeakWorkingSetSize();
ORT_RETURN_IF_ERROR(Env::Default().CreateFolder(params_.perf_output_dir));
// saving json file
ORT_RETURN_IF_ERROR(SavePerfMetrics(number_of_batches, gradient_accumulation_step_count, weight_update_steps,
total_time, avg_time_per_batch, throughput, stabilized_throughput, mapped_dimensions,
average_cpu_usage, peak_workingset_size));
}
@ -901,17 +903,17 @@ Status TrainingRunner::TrainingLoop(IDataLoader& training_data_loader, IDataLoad
Status TrainingRunner::SavePerfMetrics(const size_t number_of_batches, const size_t gradient_accumulation_steps,
const size_t weight_update_steps, const double total_time,
const double avg_time_per_batch, const double throughput, const double stabilized_throughput,
const MapStringToString& mapped_dimensions,
const short average_cpu_usage, const size_t peak_workingset_size) {
// populate metrics for reporting
json perf_metrics;
perf_metrics["Model"] = params_.model_type;
perf_metrics["Model"] = params_.model_type;
// loop thru the mapped_dimensions and put it in json sub-structure
std::string seq_len;
for (auto const& it : mapped_dimensions) {
if (it.first == "SeqLen") {
seq_len = it.second;
}
perf_metrics["DerivedProperties"][it.first] = it.second;
}
@ -929,7 +931,7 @@ Status TrainingRunner::SavePerfMetrics(const size_t number_of_batches, const siz
std::string optimizer = params_.training_optimizer_name;
std::size_t pos = optimizer.find("Optimizer");
if (pos != std::string::npos)
optimizer = optimizer.substr(0, pos);
perf_metrics["Optimizer"] = optimizer;
@ -948,7 +950,7 @@ Status TrainingRunner::SavePerfMetrics(const size_t number_of_batches, const siz
//
// we will get date/time and commitId in post-run pipeline
//
// populate other basic params for bookkeeping - add more as needed
json bookkeeping_params;
@ -962,7 +964,7 @@ Status TrainingRunner::SavePerfMetrics(const size_t number_of_batches, const siz
perf_metrics["RunConfig"] = bookkeeping_params.dump(); // serialize the params as json string
std::string json_string = perf_metrics.dump();
// write to a file - the next task in CI will pick up all files with the same prefix
const PathString perf_metrics_path =
@ -1072,7 +1074,7 @@ Status TrainingRunner::Evaluate(InferenceSession& session, IDataLoader& data_loa
fetch_names,
&fetches));
// Assume that user-specified fetches are available only on the last pipeline stage.
// When there is no pipeline, all pipeline_context_.pipeline_stage_id should be 0 and
// params_.pipeline_parallel_size is 1. Thus, the following condition is always true if there
// is no pipeline.


@ -156,8 +156,11 @@ class TrainingRunner {
// pipeline_parallel_size > 1 means pipeline is enabled.
// pipeline_parallel_size == 1 means pipeline is disabled.
int pipeline_parallel_size = 1;
// pipeline partition information used to do online partition. If the graph is
// pre-partitioned, there is no need to fill this value.
std::vector<TrainingSession::TrainingConfiguration::CutInfo> pipeline_partition_cut_list;
// model_paths[i] is the name of the pipeline stage for i-th process.
// The i-th file is run by the i-th MPI rank.
// If model_paths is not empty, model partition transformation may not be internally invoked.
VectorString pipeline_stage_paths;
// Enable gradient clipping.
@ -169,7 +172,7 @@ class TrainingRunner {
common::Status Initialize();
common::Status Run(IDataLoader* training_data_loader, IDataLoader* test_data_loader,
const MapStringToString& mapped_dimensions = {});
common::Status EndTraining(IDataLoader* data_loader);
@ -196,12 +199,12 @@ class TrainingRunner {
Status RunWithUpdate(VectorString& feed_names,
VectorString& fetch_names,
std::vector<MLValue>& feeds,
std::vector<MLValue>& fetches);
Status RunWithoutUpdate(VectorString& feed_names,
VectorString& fetch_names,
std::vector<MLValue>& feeds,
size_t& gradient_accumulation_step_count);
Status TrainingLoop(IDataLoader& training_data_loader, IDataLoader* test_data_loader,
const MapStringToString& mapped_dimensions);
Status Evaluate(InferenceSession& session, IDataLoader& data_loader);
@ -230,7 +233,7 @@ class TrainingRunner {
AllocatorPtr input_allocator_;
std::unique_ptr<CheckpointRegistry> checkpoint_registry_;
// Pipeline fields are valid only if params_.pipeline_parallel_size > 1.
// Information for running pipeline.
pipeline::PipelineContext pipeline_context_;


@ -998,7 +998,7 @@ class PipelineBatchPlanner {
};
void RetrieveEventOperators(
Graph& graph,
Node** forward_wait_before_recv,
Node** forward_wait_after_recv,
Node** forward_record_before_send,
@ -1065,7 +1065,7 @@ void RetrieveEventOperators(
}
void RetrieveSendRecvOperators(
Graph& graph,
Node** forward_recv,
Node** forward_send,
Node** backward_recv,
@ -1089,9 +1089,9 @@ void RetrieveSendRecvOperators(
if (is_backward(node)) {
// backward_send can only be assigned one valid pointer.
// If it is assigned more than once, it means we have multiple
// Send in backward pass and therefore our assumption doesn't hold.
// This check ensures that we only update *backward_send when
// its value is NULL and guards our one-Send assumption.
ASSERT_TRUE(!(*backward_send));
*backward_send = &node;
} else {
@ -1115,6 +1115,46 @@ void RetrieveSendRecvOperators(
}
}
TEST(GradientGraphBuilderTest, PipelineOnlinePartition) {
auto model_uri = ORIGINAL_MODEL_PATH;
TrainingSession::TrainingConfiguration::PipelineConfiguration pipe{};
pipe.do_partition = true;
// evenly cut the MLP model into 3 partitions
TrainingSession::TrainingConfiguration::CutInfo cut0 = {TrainingSession::TrainingConfiguration::CutEdge("T3")};
TrainingSession::TrainingConfiguration::CutInfo cut1 = {TrainingSession::TrainingConfiguration::CutEdge("T6")};
pipe.cut_list.emplace_back(cut0);
pipe.cut_list.emplace_back(cut1);
for (int i = 0; i < 3; ++i) {
#ifdef _WIN32
auto suffix = std::to_wstring(i);
#else
auto suffix = std::to_string(i);
#endif
PathString output_file = ORT_TSTR("pipeline_partition_") + suffix + ORT_TSTR("_back.onnx");
auto config = MakeBasicTrainingConfig();
config.pipeline_config = pipe;
config.distributed_config.world_rank = i;
config.distributed_config.world_size = 3;
config.distributed_config.local_rank = i;
config.distributed_config.local_size = 3;
config.distributed_config.data_parallel_size = 1;
config.distributed_config.horizontal_parallel_size = 1;
config.distributed_config.pipeline_parallel_size = 3;
config.model_with_training_graph_path = output_file;
PathString backprop_model_file;
ASSERT_STATUS_OK(BuildBackPropGraph(model_uri, config, backprop_model_file));
std::shared_ptr<Model> model;
// Ensure the partitioned model loads.
ASSERT_STATUS_OK(Model::Load(backprop_model_file, model, nullptr, DefaultLoggingManager().DefaultLogger()));
}
}
// verify the pipeline config can load and the gradient graph can be constructed.
TEST(GradientGraphBuilderTest, TrainingSession_PipelineTransform_base) {
PathString filename_base = ORT_TSTR("testdata/test_training_model_");
@ -1197,7 +1237,7 @@ TEST(GradientGraphBuilderTest, TrainingSession_PipelineTransform_base) {
Node* backward_send{nullptr};
RetrieveSendRecvOperators(
graph,
&forward_recv,
&forward_send,
&backward_recv,