mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Add pipeline online partition logic for pipeline (#3996)
* online partition * fix when multiple consumer nodes is in cut info * fix windows build * address feedback * adding test * feedback * address feedback * add parser for cut edge * windows build
This commit is contained in:
parent
0d8abc1a99
commit
633008b5ef
11 changed files with 612 additions and 77 deletions
|
|
@ -525,6 +525,11 @@ class Graph {
|
|||
return graph_inputs_including_initializers_;
|
||||
}
|
||||
|
||||
/** Return true if "node_arg" is a input or an initializer. Otherwise, returns false. */
|
||||
bool IsInputsIncludingInitializers(const NodeArg* node_arg) const noexcept{
|
||||
return std::find(graph_inputs_including_initializers_.begin(), graph_inputs_including_initializers_.end(), node_arg) != graph_inputs_including_initializers_.end();
|
||||
}
|
||||
|
||||
/** Gets the Graph inputs that are initializers
|
||||
These are overridable initializers. This is a difference between
|
||||
graph_inputs_including_initializers_ and graph_inputs_excluding_initializers_
|
||||
|
|
@ -537,6 +542,10 @@ class Graph {
|
|||
@remarks Contains no nullptr values.*/
|
||||
const std::vector<const NodeArg*>& GetOutputs() const noexcept { return graph_outputs_; }
|
||||
|
||||
bool IsOutput(const NodeArg* node_arg) const noexcept{
|
||||
return std::find(graph_outputs_.begin(), graph_outputs_.end(), node_arg) != graph_outputs_.end();
|
||||
}
|
||||
|
||||
/** Returns a vector with the indexes of the outputs of the given Node that are also Graph outputs. */
|
||||
std::vector<int> GetNodeOutputsInGraphOutputs(const Node& node) const {
|
||||
int output_idx = 0;
|
||||
|
|
@ -557,6 +566,8 @@ class Graph {
|
|||
@remarks Contains no nullptr values. */
|
||||
const std::vector<const NodeArg*>& GetValueInfo() const noexcept;
|
||||
|
||||
void AddValueInfo(const NodeArg* new_value_info);
|
||||
|
||||
/** Gets the Node with the specified node index.
|
||||
@returns Node instance if found. nullptr if node_index is invalid or node has been freed.
|
||||
*/
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ namespace onnxruntime {
|
|||
class Node;
|
||||
|
||||
/**
|
||||
Class to filter out null entries from either a vector of unique_ptr<Node> or a vector of [const] Node* and
|
||||
Class to filter out null entries from either a vector of unique_ptr<Node> or a vector of [const] Node* and
|
||||
provide an iterator interface that returns [const] Node& for the valid entries.
|
||||
*/
|
||||
template <typename TNodesContainer>
|
||||
|
|
@ -29,6 +29,7 @@ class ValidNodes {
|
|||
|
||||
using ConstNodeIterator = NodeIterator<typename TNodesContainer::const_iterator>;
|
||||
using MutableNodeIterator = NodeIterator<typename TNodesContainer::iterator>;
|
||||
using ConstReverseNodeIterator = NodeIterator<typename TNodesContainer::const_reverse_iterator>;
|
||||
|
||||
ConstNodeIterator cbegin() const noexcept {
|
||||
return {nodes_.cbegin(), nodes_.cend()};
|
||||
|
|
@ -46,6 +47,14 @@ class ValidNodes {
|
|||
return cend();
|
||||
}
|
||||
|
||||
ConstReverseNodeIterator rbegin() const noexcept {
|
||||
return {nodes_.crbegin(), nodes_.crend()};
|
||||
}
|
||||
|
||||
ConstReverseNodeIterator rend() const noexcept {
|
||||
return {nodes_.crend(), nodes_.crend()};
|
||||
}
|
||||
|
||||
MutableNodeIterator begin() noexcept {
|
||||
return {nodes_.begin(), nodes_.end()};
|
||||
}
|
||||
|
|
@ -56,10 +65,10 @@ class ValidNodes {
|
|||
|
||||
bool empty() const noexcept { return nodes_.empty(); }
|
||||
|
||||
/**
|
||||
/**
|
||||
@class NodeIterator
|
||||
Iterator to provide const and non-const access to valid Node instances in a Graph.
|
||||
@remarks Skips invalid nodes.
|
||||
@remarks Skips invalid nodes.
|
||||
*/
|
||||
template <typename TIterator>
|
||||
class NodeIterator {
|
||||
|
|
@ -130,7 +139,7 @@ class ValidNodes {
|
|||
};
|
||||
|
||||
/**
|
||||
Class that provides iteration over all valid nodes in the Graph.
|
||||
Class that provides iteration over all valid nodes in the Graph.
|
||||
*/
|
||||
class GraphNodes : public ValidNodes<std::vector<std::unique_ptr<Node>>> {
|
||||
public:
|
||||
|
|
|
|||
|
|
@ -2337,6 +2337,13 @@ const std::vector<const NodeArg*>& Graph::GetValueInfo() const noexcept {
|
|||
return value_info_;
|
||||
}
|
||||
|
||||
void Graph::AddValueInfo(const NodeArg* new_value_info){
|
||||
for(const auto* info : value_info_){
|
||||
ORT_ENFORCE(info->Name() != new_value_info->Name(), "Error: trying to add an existing value info.");
|
||||
}
|
||||
value_info_.push_back(new_value_info);
|
||||
}
|
||||
|
||||
std::vector<NodeArg*> Graph::CreateNodeArgs(const google::protobuf::RepeatedPtrField<std::string>& names,
|
||||
const ArgNameToTypeMap& name_to_type_map) {
|
||||
const auto name_to_type_map_end = name_to_type_map.end();
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include "orttraining/core/graph/pipeline_transformer.h"
|
||||
#include <queue>
|
||||
|
||||
using namespace onnxruntime::common;
|
||||
|
||||
|
|
@ -22,21 +23,22 @@ bool IsBackward(Node& node) {
|
|||
return (node.Description() == "Backward pass");
|
||||
}
|
||||
|
||||
NodeArg& CreateInt64NodeArg(Graph& graph, const std::string& name) {
|
||||
NodeArg& CreateTypedNodeArg(Graph& graph, onnx::TensorProto_DataType type, const std::string& name) {
|
||||
ONNX_NAMESPACE::TypeProto type_proto;
|
||||
type_proto.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
|
||||
type_proto.mutable_tensor_type()->set_elem_type(type);
|
||||
auto actual_name = graph.GenerateNodeArgName(name);
|
||||
auto& node_arg = graph.GetOrCreateNodeArg(actual_name, &type_proto);
|
||||
return node_arg;
|
||||
}
|
||||
|
||||
void AddInputEvent(Graph& graph,
|
||||
const std::string& event_name,
|
||||
std::vector<NodeArg*>& input_args,
|
||||
std::vector<std::string>& new_input_names) {
|
||||
auto& event_id = CreateInt64NodeArg(graph, event_name);
|
||||
new_input_names.push_back(event_id.Name());
|
||||
input_args.push_back(&event_id);
|
||||
void AddNewNodeArg(Graph& graph,
|
||||
const std::string& op_name,
|
||||
onnx::TensorProto_DataType type,
|
||||
std::vector<NodeArg*>& new_node_args,
|
||||
std::vector<std::string>& new_names) {
|
||||
auto& new_node_arg = CreateTypedNodeArg(graph, type, op_name);
|
||||
new_names.push_back(new_node_arg.Name());
|
||||
new_node_args.push_back(&new_node_arg);
|
||||
}
|
||||
|
||||
// gradient graph can contain some dangling leaf nodes. Add them all to WaitEvent
|
||||
|
|
@ -112,7 +114,7 @@ Node& CreateBottleneckNode(Graph& graph,
|
|||
nullptr /* assume all bottleneck node have no attributes */,
|
||||
kMSDomain);
|
||||
}
|
||||
|
||||
|
||||
Node* AddBackwardRecord(Graph& graph,
|
||||
Node* backward_send,
|
||||
std::vector<std::string>& new_input_names,
|
||||
|
|
@ -120,7 +122,8 @@ Node* AddBackwardRecord(Graph& graph,
|
|||
std::string &event_id_tensor_name,
|
||||
std::string &output_tensor_name) {
|
||||
std::vector<NodeArg*> input_args;
|
||||
AddInputEvent(graph, "backward_recorded_event_id", input_args, new_input_names);
|
||||
AddNewNodeArg(graph, "backward_recorded_event_id", ONNX_NAMESPACE::TensorProto_DataType_INT64,
|
||||
input_args, new_input_names);
|
||||
std::vector<NodeArg*> output_args{};
|
||||
|
||||
if (backward_send) {
|
||||
|
|
@ -175,10 +178,11 @@ Node* AddForwardWait(Graph& graph,
|
|||
|
||||
std::vector<NodeArg*> input_args;
|
||||
std::vector<NodeArg*> output_args;
|
||||
AddInputEvent(graph, "forward_waited_event_id", input_args, new_input_names);
|
||||
AddNewNodeArg(graph, "forward_waited_event_id", ONNX_NAMESPACE::TensorProto_DataType_INT64,
|
||||
input_args, new_input_names);
|
||||
const std::vector<const NodeArg*>& graph_inputs = graph.GetInputsIncludingInitializers();
|
||||
|
||||
if (graph_inputs.size() == 0){
|
||||
if (graph_inputs.size() == 0) {
|
||||
ORT_THROW("Graph ", graph.Name(), " doesn't have any inputs.");
|
||||
}
|
||||
|
||||
|
|
@ -237,7 +241,8 @@ Status AddOrSkipForwardRecordBackwardWait(Graph& graph,
|
|||
{
|
||||
std::vector<NodeArg*> input_args;
|
||||
std::vector<NodeArg*> output_args;
|
||||
AddInputEvent(graph, "forward_recorded_event_id", input_args, new_input_names);
|
||||
AddNewNodeArg(graph, "forward_recorded_event_id", ONNX_NAMESPACE::TensorProto_DataType_INT64,
|
||||
input_args, new_input_names);
|
||||
|
||||
// Add send forward op's output as record op's input and output
|
||||
for (auto& output : forward_send->MutableOutputDefs()) {
|
||||
|
|
@ -258,7 +263,8 @@ Status AddOrSkipForwardRecordBackwardWait(Graph& graph,
|
|||
{
|
||||
std::vector<NodeArg*> input_args;
|
||||
std::vector<NodeArg*> output_args;
|
||||
AddInputEvent(graph, "backward_waited_event_id", input_args, new_input_names);
|
||||
AddNewNodeArg(graph, "backward_waited_event_id", ONNX_NAMESPACE::TensorProto_DataType_INT64,
|
||||
input_args, new_input_names);
|
||||
|
||||
input_args.insert(std::end(input_args),
|
||||
std::begin(record_node->MutableOutputDefs()),
|
||||
|
|
@ -287,7 +293,6 @@ void ReplaceNodeArgs(std::vector<Node*>& nodes,
|
|||
ORT_ENFORCE(node_args.size() == new_node_args.size());
|
||||
for (size_t i = 0; i < node_args.size(); ++i) {
|
||||
// Iteration for node_args[i] and new_node_args[i].
|
||||
|
||||
ORT_ENFORCE(node_args[i]->Name() != new_node_args[i]->Name());
|
||||
ORT_ENFORCE(node_args[i]->Type() == new_node_args[i]->Type());
|
||||
|
||||
|
|
@ -331,10 +336,10 @@ std::string AddEventBeforeNode(
|
|||
std::vector<Node*> nodes = {node};
|
||||
|
||||
// Replace node_args[i] with new_node_args[i] in nodes.
|
||||
ReplaceNodeArgs(nodes, node_args, new_node_args);
|
||||
ReplaceNodeArgs(nodes, node_args, new_node_args);
|
||||
|
||||
// Create node_arg for event ID.
|
||||
auto event_node_arg = &CreateInt64NodeArg(graph, event_id_name);
|
||||
auto event_node_arg = &CreateTypedNodeArg(graph, ONNX_NAMESPACE::TensorProto_DataType_INT64, event_id_name);
|
||||
|
||||
// Create node which produces new_node_args from event ID and node_args.
|
||||
CreateBottleneckNode(graph,
|
||||
|
|
@ -372,11 +377,11 @@ std::string AddEventAfterNode(
|
|||
std::vector<Node*> consumer_nodes = graph.GetMutableConsumerNodes(
|
||||
node_args[i]->Name());
|
||||
// Replace node_args[i] with new_node_args[i] in nodes.
|
||||
ReplaceNodeArgs(consumer_nodes, {node_args[i]}, {new_node_args[i]});
|
||||
ReplaceNodeArgs(consumer_nodes, {node_args[i]}, {new_node_args[i]});
|
||||
}
|
||||
|
||||
// Create node_arg for event ID.
|
||||
auto event_node_arg = &CreateInt64NodeArg(graph, event_id_name);
|
||||
auto event_node_arg = &CreateTypedNodeArg(graph, ONNX_NAMESPACE::TensorProto_DataType_INT64, event_id_name);
|
||||
|
||||
// Create node which produces new_node_args from event ID and node_args.
|
||||
CreateBottleneckNode(graph,
|
||||
|
|
@ -395,10 +400,10 @@ Status AddForwardWaitAfterRecv(
|
|||
Node* comm_node,
|
||||
std::vector<std::string>& new_input_names,
|
||||
std::string& event_name) {
|
||||
event_name = AddEventAfterNode(
|
||||
event_name = AddEventAfterNode(
|
||||
graph, comm_node,
|
||||
"WaitEvent", "forward_wait_after_recv",
|
||||
"forward_wait_after_recv_event_id");
|
||||
"forward_wait_after_recv_event_id");
|
||||
if (event_name.empty()) {
|
||||
return Status::OK();
|
||||
} else {
|
||||
|
|
@ -458,6 +463,37 @@ Status AddBackwardRecordBeforeSend(
|
|||
}
|
||||
}
|
||||
|
||||
Status SetInputsOutputsAndResolve(Graph& graph,
|
||||
const std::vector<std::string>& new_input_names,
|
||||
const std::vector<std::string>& new_output_names) {
|
||||
auto fill_node_args = [&](const Graph& graph,
|
||||
const std::vector<const NodeArg*>& existed_node_args,
|
||||
const std::vector<std::string>& new_node_arg_names,
|
||||
std::vector<const NodeArg*>& merged_node_args) {
|
||||
merged_node_args.insert(merged_node_args.end(), existed_node_args.begin(), existed_node_args.end());
|
||||
for (auto& name : new_node_arg_names) {
|
||||
merged_node_args.push_back(graph.GetNodeArg(name));
|
||||
}
|
||||
};
|
||||
|
||||
const std::vector<const NodeArg*>& graph_inputs = graph.GetInputsIncludingInitializers();
|
||||
std::vector<const NodeArg*> inputs_args_sets;
|
||||
inputs_args_sets.reserve(graph_inputs.size() + new_input_names.size());
|
||||
fill_node_args(graph, graph_inputs, new_input_names, inputs_args_sets);
|
||||
|
||||
const std::vector<const NodeArg*>& graph_outputs = graph.GetOutputs();
|
||||
std::vector<const NodeArg*> outputs_args_sets;
|
||||
outputs_args_sets.reserve(graph_outputs.size() + new_output_names.size());
|
||||
fill_node_args(graph, graph_outputs, new_output_names, outputs_args_sets);
|
||||
|
||||
graph.SetInputs(inputs_args_sets);
|
||||
graph.SetOutputs(outputs_args_sets);
|
||||
graph.SetGraphResolveNeeded();
|
||||
graph.SetGraphProtoSyncNeeded();
|
||||
|
||||
return graph.Resolve();
|
||||
}
|
||||
|
||||
// This function inserts WaitEvent's and RecordEvent's to the input graph for
|
||||
// controlling synchronization between (batch, pipeline stage)-pairs.
|
||||
//
|
||||
|
|
@ -570,7 +606,7 @@ Status TransformGraphForPipeline(
|
|||
backward_waited_event_name,
|
||||
forward_record_output_name,
|
||||
backward_wait_output_name));
|
||||
|
||||
|
||||
// Different stages have different patterns of Send & Recv.
|
||||
// For different patterns, we add different WaitEvent and Record.
|
||||
//
|
||||
|
|
@ -691,33 +727,357 @@ Status TransformGraphForPipeline(
|
|||
));
|
||||
}
|
||||
|
||||
auto fill_node_args = [&](const Graph& graph,
|
||||
const std::vector<const NodeArg*>& existed_node_args,
|
||||
std::vector<std::string>& new_node_arg_names,
|
||||
std::vector<const NodeArg*>& merged_node_args) {
|
||||
merged_node_args.insert(merged_node_args.end(), existed_node_args.begin(), existed_node_args.end());
|
||||
for (auto& name : new_node_arg_names) {
|
||||
merged_node_args.push_back(graph.GetNodeArg(name));
|
||||
ORT_RETURN_IF_ERROR(SetInputsOutputsAndResolve(graph, new_input_names, new_output_names));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// This function is used when you want to create a scalar constant in a graph.
|
||||
// It may create a NodeArg so that other Node can references its value.
|
||||
// It also cerates an initializer to store its value.
|
||||
template <typename T>
|
||||
void AddNewScalarNodeArgAndInitializer(Graph& graph,
|
||||
const std::string& op_name,
|
||||
onnx::TensorProto_DataType type,
|
||||
T data,
|
||||
std::vector<NodeArg*>& new_node_args,
|
||||
std::vector<std::string>& new_names) {
|
||||
AddNewNodeArg(graph, op_name, type, new_node_args, new_names);
|
||||
|
||||
ONNX_NAMESPACE::TensorProto proto_data;
|
||||
proto_data.set_name(new_names.back());
|
||||
proto_data.set_data_type(type);
|
||||
|
||||
switch (type) {
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
|
||||
proto_data.add_int32_data(static_cast<int32_t>(data));
|
||||
break;
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_INT64:
|
||||
proto_data.add_int64_data(static_cast<int64_t>(data));
|
||||
break;
|
||||
default:
|
||||
ORT_THROW("pipeline partition unsupported 'type' value: ", type);
|
||||
}
|
||||
graph.AddInitializedTensor(proto_data);
|
||||
}
|
||||
|
||||
// split the graph into disconnected subgraph based on provided CutInfo
|
||||
common::Status SplitGraph(Graph& graph,
|
||||
std::vector<TrainingSession::TrainingConfiguration::CutInfo> split_edge_groups,
|
||||
std::vector<Node*>& send_nodes,
|
||||
std::vector<Node*>& recv_nodes) {
|
||||
std::vector<std::string> new_input_names;
|
||||
std::vector<std::string> new_output_names;
|
||||
|
||||
// updated_node_args keeps track of the mapping between the original node_arg and its corresponding updated
|
||||
// node_arg after send and recv node is added. As multiple partitions can happen, and a single node_arg
|
||||
// can belong to different partition, updated_node_args always keeps track of the latest updated node_arg.
|
||||
// Below is one example of how this works using update_node_args:
|
||||
// there are three edges in graph, specified as nodeA->nodeB, nodeA->nodeC, and nodeA->nodeD.
|
||||
// those edges all share the same node_arg.
|
||||
// but nodeA, nodeB belong to parition0, nodeC belongs to parition1, and nodeD belongs to parition2.
|
||||
// This means we need to cut edge nodeA->nodeC for the first partition and nodeA->nodeD for the second partition.
|
||||
//
|
||||
// During the first cut, we identify the edge nodeA->nodeC, for this edge, based on the origional node_arg,
|
||||
// we create a new node_arg, called updated_node_arg. The inserted send node will take the original node_arg
|
||||
// as input and the inserted recv node will take the updated_node_arg as the output.
|
||||
// And we update updated_node_args with updated_node_args[original_node_arg] = updated_node_arg
|
||||
//
|
||||
// Now during the second cut, we need to cut the edge nodeA->nodeD. Noted that as the cut is performed in sequential,
|
||||
// the second cut is performed based on the graph modified after the first cut. This means, the input node_arg for
|
||||
// nodeD shouldn't come from nodeA anymore, as nodeA now residents in partition0, which is a disconnected partition.
|
||||
// Instead, the input node_arg of nodeD should come from the updated version: updated_node_arg from partition1.
|
||||
// By using the updated_node_args map, we can retrieve updated_node_arg from original_node_arg, and use that as the
|
||||
// newly inserted send's input. Also, to keep this on going for any following cut, we create an updated_node_arg_v2,
|
||||
// and update updated_node_args with updated_node_args[original_node_arg] = updated_node_arg_v2
|
||||
std::map<NodeArg*, NodeArg*> updated_node_args;
|
||||
for (size_t index = 0; index < split_edge_groups.size(); ++index) {
|
||||
// each entry in split_edge_groups represents a partition cut. Each cut can contain the split of
|
||||
// several edges.
|
||||
auto& edgeIds = split_edge_groups[index];
|
||||
|
||||
// for each cut, record the inserted input/output args.
|
||||
std::vector<NodeArg*> send_input_args;
|
||||
std::vector<NodeArg*> send_output_args;
|
||||
std::vector<NodeArg*> recv_input_args;
|
||||
std::vector<NodeArg*> recv_output_args;
|
||||
|
||||
auto cut_index_str = std::to_string(index);
|
||||
// add input node_arg and initializer for send/recv
|
||||
AddNewScalarNodeArgAndInitializer<bool>(graph,
|
||||
"send_input_signal" + cut_index_str,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_BOOL,
|
||||
true, /* initializer data */
|
||||
send_input_args,
|
||||
new_input_names);
|
||||
AddNewScalarNodeArgAndInitializer<bool>(graph,
|
||||
"recv_input_signal" + cut_index_str,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_BOOL,
|
||||
true, /* initializer data */
|
||||
recv_input_args,
|
||||
new_input_names);
|
||||
|
||||
AddNewScalarNodeArgAndInitializer<size_t>(graph,
|
||||
"send_dst_rank" + cut_index_str,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_INT64,
|
||||
index + 1, /* initializer data */
|
||||
send_input_args,
|
||||
new_input_names);
|
||||
AddNewScalarNodeArgAndInitializer<size_t>(graph,
|
||||
"recv_src_rank" + cut_index_str,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_INT64,
|
||||
index, /* initializer data */
|
||||
recv_input_args,
|
||||
new_input_names);
|
||||
// add output node_arg for send/recv
|
||||
AddNewNodeArg(graph, "send_output_signal" + cut_index_str, ONNX_NAMESPACE::TensorProto_DataType_BOOL,
|
||||
send_output_args, new_output_names);
|
||||
AddNewNodeArg(graph, "receive_output_signal" + cut_index_str, ONNX_NAMESPACE::TensorProto_DataType_BOOL,
|
||||
recv_output_args, new_output_names);
|
||||
|
||||
// add attribute data for send/recv
|
||||
ONNX_NAMESPACE::AttributeProto tag;
|
||||
tag.set_name("tag");
|
||||
tag.set_type(ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_INT);
|
||||
// currently hard-coded all tag to be 0. May need to change when multiple GPU stream is used.
|
||||
tag.set_i(static_cast<int64_t>(0));
|
||||
|
||||
ONNX_NAMESPACE::AttributeProto element_types;
|
||||
element_types.set_name("element_types");
|
||||
element_types.set_type(ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_INTS);
|
||||
|
||||
// for each edge in this group, perform edge cut
|
||||
for (auto& id : edgeIds) {
|
||||
// find node whose output contains id.node_arg_name
|
||||
auto producer_node = graph.GetMutableProducerNode(id.node_arg_name);
|
||||
if (!producer_node) {
|
||||
ORT_THROW("Cannot find producer node of node_arg with name: ", id.node_arg_name, ". Wrong cutting infomation.");
|
||||
}
|
||||
|
||||
// once we find out the producer node for id.node_arg_name, find which output index that leads
|
||||
// to id.node_arg_name
|
||||
int upstream_nodes_output_index{-1};
|
||||
producer_node->ForEachWithIndex(
|
||||
producer_node->OutputDefs(),
|
||||
[&](const NodeArg& def, size_t index) {
|
||||
if (def.Name() == id.node_arg_name) {
|
||||
upstream_nodes_output_index = static_cast<int>(index);
|
||||
}
|
||||
return Status::OK();
|
||||
});
|
||||
|
||||
if (upstream_nodes_output_index < 0) {
|
||||
ORT_THROW("Node with name: ", producer_node->Name(),
|
||||
" doesn't have an output node_arg with name ", id.node_arg_name);
|
||||
}
|
||||
|
||||
size_t idx = static_cast<size_t>(upstream_nodes_output_index);
|
||||
|
||||
// original node_arg pointer from the origin graph. This serves as the key in the
|
||||
// updated_node_arg map and any reference for original node_arg name
|
||||
auto* original_node_arg = producer_node->MutableOutputDefs()[idx];
|
||||
|
||||
// updated node_arg pointer from previous partition. This is the new arg_node the
|
||||
// current inserted send node will take as input node_arg.
|
||||
auto updated_node_arg = producer_node->MutableOutputDefs()[idx];
|
||||
auto exiting_updated_node_arg = updated_node_args.find(original_node_arg);
|
||||
if (exiting_updated_node_arg != updated_node_args.end()) {
|
||||
updated_node_arg = exiting_updated_node_arg->second;
|
||||
}
|
||||
|
||||
send_input_args.push_back(updated_node_arg);
|
||||
|
||||
auto dtype = original_node_arg->TypeAsProto()->tensor_type().elem_type();
|
||||
switch (dtype) {
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
|
||||
element_types.add_ints(static_cast<int64_t>(1));
|
||||
break;
|
||||
default:
|
||||
// Assume all tensors are of type float.
|
||||
// TODO: update if graph supports other data type.
|
||||
ORT_THROW("pipeline partition unsupported 'type' value: ", dtype);
|
||||
}
|
||||
|
||||
auto& new_receive_output = CreateNodeArg(graph, updated_node_arg);
|
||||
const auto old_shape = *(updated_node_arg->Shape());
|
||||
new_receive_output.SetShape(old_shape);
|
||||
recv_output_args.push_back(&new_receive_output);
|
||||
|
||||
// add value info for this newly added receive_output, for shape propagation
|
||||
// when training this partition.
|
||||
graph.AddValueInfo(&new_receive_output);
|
||||
|
||||
// update updated_node_args with the newly created node_arg
|
||||
updated_node_args[original_node_arg] = &new_receive_output;
|
||||
|
||||
// deal with shape inference for newly added edge
|
||||
auto& output_edge_name = original_node_arg->Name();
|
||||
|
||||
// deal with updating the consumer's input node_args
|
||||
std::vector<Node*> consumer_nodes;
|
||||
if (id.consumer_nodes.has_value()) {
|
||||
for(auto& consumer_node_id : id.consumer_nodes.value()){
|
||||
consumer_nodes.push_back(graph.GetMutableProducerNode(consumer_node_id));
|
||||
}
|
||||
} else {
|
||||
consumer_nodes = graph.GetMutableConsumerNodes(output_edge_name);
|
||||
}
|
||||
|
||||
for (auto consumer_node : consumer_nodes) {
|
||||
for (auto& input : consumer_node->MutableInputDefs()) {
|
||||
if (input->Name() == output_edge_name) {
|
||||
input = &new_receive_output;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
const int num_attributes = 2; // two attributes: tag and element_types
|
||||
NodeAttributes attributes;
|
||||
attributes.reserve(num_attributes);
|
||||
attributes[tag.name()] = tag;
|
||||
attributes[element_types.name()] = element_types;
|
||||
|
||||
const std::vector<const NodeArg*>& graph_inputs = graph.GetInputsIncludingInitializers();
|
||||
std::vector<const NodeArg*> inputs_args_sets;
|
||||
inputs_args_sets.reserve(graph_inputs.size() + new_input_names.size());
|
||||
fill_node_args(graph, graph_inputs, new_input_names, inputs_args_sets);
|
||||
auto& send_node = graph.AddNode(graph.GenerateNodeName("Send"),
|
||||
"Send",
|
||||
"",
|
||||
send_input_args,
|
||||
send_output_args, /* output */
|
||||
&attributes, /* attribute */
|
||||
kMSDomain);
|
||||
|
||||
const std::vector<const NodeArg*>& graph_outputs = graph.GetOutputs();
|
||||
std::vector<const NodeArg*> outputs_args_sets;
|
||||
outputs_args_sets.reserve(graph_outputs.size() + new_output_names.size());
|
||||
fill_node_args(graph, graph_outputs, new_output_names, outputs_args_sets);
|
||||
send_nodes.push_back(&send_node);
|
||||
|
||||
graph.SetInputs(inputs_args_sets);
|
||||
graph.SetOutputs(outputs_args_sets);
|
||||
auto& recv_node = graph.AddNode(graph.GenerateNodeName("Recv"),
|
||||
"Recv",
|
||||
"",
|
||||
recv_input_args,
|
||||
recv_output_args, /* output */
|
||||
&attributes, /* attribute */
|
||||
kMSDomain);
|
||||
recv_nodes.push_back(&recv_node);
|
||||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR(SetInputsOutputsAndResolve(graph, new_input_names, new_output_names));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status FindAllConnectedNodes(Graph& graph,
|
||||
const Node* node,
|
||||
std::vector<const Node*>& connected_nodes,
|
||||
std::set<const NodeArg*>& connected_inputs,
|
||||
std::set<const NodeArg*>& connected_outputs) {
|
||||
ORT_THROW_IF_ERROR(node->ForEachWithIndex(
|
||||
node->InputDefs(),
|
||||
[&](const NodeArg& node_arg, size_t /*index*/) {
|
||||
if (graph.IsInputsIncludingInitializers(&node_arg)) {
|
||||
connected_inputs.insert(&node_arg);
|
||||
} else {
|
||||
const Node* producer_node = graph.GetProducerNode(node_arg.Name());
|
||||
if (producer_node == nullptr) {
|
||||
// got nullptr as producer node. This could be because the input is a constant op which will be optimized
|
||||
// away. Print out this information and continue.
|
||||
LOGS_DEFAULT(WARNING) << "Cannot find producer node for node_arg: " << node_arg.Name() << ". Skipping this node.";
|
||||
} else {
|
||||
connected_nodes.push_back(producer_node);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}));
|
||||
|
||||
ORT_THROW_IF_ERROR(node->ForEachWithIndex(
|
||||
node->OutputDefs(),
|
||||
[&](const NodeArg& node_arg, size_t /*index*/) {
|
||||
if (!graph.IsOutput(&node_arg)) {
|
||||
std::vector<const Node*> consumer_nodes = graph.GetConsumerNodes(node_arg.Name());
|
||||
connected_nodes.insert(std::end(connected_nodes), consumer_nodes.begin(), consumer_nodes.end());
|
||||
|
||||
} else {
|
||||
connected_outputs.insert(&node_arg);
|
||||
}
|
||||
return Status::OK();
|
||||
}));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// traverse the graph from start_node to get the set of nodes contains in this disconnected subgraph
|
||||
common::Status GenerateSubgraph(Graph& graph, const Node* start_node) {
|
||||
std::queue<const Node*> node_queue;
|
||||
node_queue.push(start_node);
|
||||
|
||||
std::set<const Node*> visited_nodes;
|
||||
std::set<const NodeArg*> visited_inputs;
|
||||
std::set<const NodeArg*> visited_outputs;
|
||||
|
||||
// BFS graph traverse
|
||||
while (!node_queue.empty()) {
|
||||
auto node = node_queue.front();
|
||||
node_queue.pop();
|
||||
if (visited_nodes.count(node) == 0) {
|
||||
visited_nodes.insert(node);
|
||||
std::vector<const Node*> connected_nodes;
|
||||
ORT_THROW_IF_ERROR(FindAllConnectedNodes(graph, node, connected_nodes, visited_inputs, visited_outputs));
|
||||
|
||||
for (auto n : connected_nodes) {
|
||||
ORT_ENFORCE(n!=nullptr, "Found nullptr in searching for connected nodes");
|
||||
node_queue.push(n);
|
||||
}
|
||||
}
|
||||
}
|
||||
std::set<NodeIndex> visited_node_index;
|
||||
for (auto n : visited_nodes) {
|
||||
visited_node_index.insert(n->Index());
|
||||
}
|
||||
|
||||
GraphViewer graph_viewer(graph);
|
||||
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
|
||||
|
||||
// reverse iterate the nodes in tolopogical order, and delete those not visited
|
||||
for (auto it = node_topology_list.rbegin(); it != node_topology_list.rend(); it++){
|
||||
if (visited_node_index.count(*it)==0){
|
||||
graph.RemoveNode(*it);
|
||||
}
|
||||
}
|
||||
|
||||
// update the grah with only visited inputs and outputs
|
||||
graph.SetInputs({visited_inputs.begin(), visited_inputs.end()});
|
||||
graph.SetOutputs({visited_outputs.begin(), visited_outputs.end()});
|
||||
graph.SetGraphResolveNeeded();
|
||||
graph.SetGraphProtoSyncNeeded();
|
||||
|
||||
return graph.Resolve();
|
||||
}
|
||||
|
||||
Status ApplyPipelinePartitionToMainGraph(
|
||||
Graph& graph,
|
||||
const std::vector<TrainingSession::TrainingConfiguration::CutInfo>& cut_info,
|
||||
size_t pipeline_stage_id,
|
||||
size_t num_pipeline_stage) {
|
||||
size_t split_count = cut_info.size();
|
||||
|
||||
if (num_pipeline_stage != split_count + 1) {
|
||||
ORT_THROW("Wrong pipeline partition cutting info. Total pipeline stage number is ",
|
||||
num_pipeline_stage,
|
||||
", cut info length is: ",
|
||||
split_count);
|
||||
}
|
||||
|
||||
std::vector<Node *> send_nodes, recv_nodes;
|
||||
send_nodes.reserve(split_count);
|
||||
recv_nodes.reserve(split_count);
|
||||
ORT_RETURN_IF_ERROR(SplitGraph(graph, cut_info, send_nodes, recv_nodes));
|
||||
|
||||
if (send_nodes.size() != split_count || recv_nodes.size() != split_count) {
|
||||
ORT_THROW("Split error: not all cut has Send and Recv inserted. Send node count: ",
|
||||
send_nodes.size(), ", Recv node count: ", recv_nodes.size(), ", split count: ", split_count);
|
||||
}
|
||||
|
||||
if (pipeline_stage_id < split_count) {
|
||||
ORT_RETURN_IF_ERROR(GenerateSubgraph(graph, send_nodes[pipeline_stage_id]));
|
||||
} else {
|
||||
ORT_RETURN_IF_ERROR(GenerateSubgraph(graph, recv_nodes.back()));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace training
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -4,12 +4,14 @@
|
|||
#pragma once
|
||||
|
||||
#include "core/graph/graph.h"
|
||||
#include "orttraining/core/session/training_session.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace training {
|
||||
|
||||
void GetPipelineSendOutput(const Graph& graph, std::string& loss_name);
|
||||
common::Status TransformGraphForPipeline(
|
||||
|
||||
Status TransformGraphForPipeline(
|
||||
Graph& graph,
|
||||
std::string& forward_waited_event_name,
|
||||
std::string& forward_recorded_event_name,
|
||||
|
|
@ -24,5 +26,10 @@ common::Status TransformGraphForPipeline(
|
|||
std::string& backward_waited_event_after_recv_name,
|
||||
std::string& backward_recorded_event_before_send_name);
|
||||
|
||||
Status ApplyPipelinePartitionToMainGraph(
|
||||
Graph& graph,
|
||||
const std::vector<TrainingSession::TrainingConfiguration::CutInfo>& cut_info,
|
||||
size_t pipeline_stage_id,
|
||||
size_t num_pipeline_stage);
|
||||
} // namespace training
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -141,6 +141,15 @@ Status TrainingSession::ConfigureForTraining(
|
|||
excluded_initializers.erase(weight_name_to_not_train);
|
||||
}
|
||||
|
||||
if (config.pipeline_config.has_value() && config.pipeline_config.value().do_partition) {
|
||||
// Apply online pipeline partition to graph obj. This needs to be done first before any graph
|
||||
// transformation which may alter node_arg and invalidate cut_list info from the original graph.
|
||||
ORT_RETURN_IF_ERROR(ApplyPipelinePartitionToMainGraph(model_->MainGraph(),
|
||||
config.pipeline_config.value().cut_list,
|
||||
config.distributed_config.world_rank,
|
||||
config.distributed_config.world_size));
|
||||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR(ApplyTransformationsToMainGraph(excluded_initializers));
|
||||
|
||||
is_mixed_precision_enabled_ = config.mixed_precision_config.has_value();
|
||||
|
|
@ -233,7 +242,7 @@ Status TrainingSession::ConfigureForTraining(
|
|||
}
|
||||
pipeline_result.fetch_names.push_back(name);
|
||||
}
|
||||
pipeline_result.pipeline_stage_id = config.distributed_config.world_rank /
|
||||
pipeline_result.pipeline_stage_id = config.distributed_config.world_rank /
|
||||
(config.distributed_config.data_parallel_size * config.distributed_config.horizontal_parallel_size);
|
||||
config_result.pipeline_config_result = pipeline_result;
|
||||
}
|
||||
|
|
@ -307,13 +316,19 @@ Status TrainingSession::ConfigureForTraining(
|
|||
tensorboard_config.histogram_node_names, tensorboard_config.norm_node_names,
|
||||
tensorboard_config.dump_convergence_metrics));
|
||||
}
|
||||
|
||||
|
||||
// add GIST encoding
|
||||
if (config.gist_config.has_value()) {
|
||||
ORT_RETURN_IF_ERROR(AddGistEncoding());
|
||||
}
|
||||
|
||||
if (IsRootNode(config) && config.model_with_training_graph_path.has_value()) {
|
||||
// If the current node is in rank0 or if the current session is running pipeline (in which case different rank would
|
||||
// store different model partition), and if model_with_training_graph_path is specified, save the model.
|
||||
// Note: in the pipeline case, different ranks may resident in the same node. This could lead to a potential write
|
||||
// conflict. It is user's responsibility to make sure different rank is passed in with different
|
||||
// model_with_training_graph_path value.
|
||||
if ((IsRootNode(config) || config.pipeline_config.has_value())
|
||||
&& config.model_with_training_graph_path.has_value()) {
|
||||
ORT_IGNORE_RETURN_VALUE(Save(
|
||||
config.model_with_training_graph_path.value(), SaveOption::NO_RELOAD));
|
||||
}
|
||||
|
|
|
|||
|
|
@ -135,6 +135,28 @@ class TrainingSession : public InferenceSession {
|
|||
// If not provided, no optimizer is added.
|
||||
optional<OptimizerConfiguration> optimizer_config{};
|
||||
|
||||
// struct to describe a specific edge. An edge is not the same as a node_arg. Edge represents a connection between two operators.
|
||||
// For example, an operator A's output tensor T is connecting to another operator B's input, then this constructs
|
||||
// an edge from A to B. If A's output tensor T has multiple consumers, i.e. it's fed into multiple operators' inputs,
|
||||
// there would be multiple edges, each from A, to a consumer operator.
|
||||
// CutEdge information is used in pipeline online partition tool to identify which edge to cut to make the
|
||||
// corresponding partition.
|
||||
struct CutEdge {
|
||||
std::string node_arg_name;
|
||||
optional<std::vector<std::string>> consumer_nodes;
|
||||
|
||||
// If the edge is unique, i.e. only have one consumer node, or all the edges
|
||||
// with the same node_arg_name needs to be cut, specify the node_arg_name
|
||||
// suffices.
|
||||
CutEdge(std::string edge) : node_arg_name(edge){};
|
||||
// If the edges with same node_arg_name belongs to different cut, i.e. some of its
|
||||
// consumer node belongs to one partition, and some belongs to another, specify
|
||||
// the consumer node names which you want to perform the cut on.
|
||||
CutEdge(std::string edge, std::vector<std::string> nodes) : node_arg_name(edge), consumer_nodes(nodes){};
|
||||
};
|
||||
// CutInfo is a group of CutEdges that describes a specific cut that composed of splitting those edges.
|
||||
typedef std::vector<CutEdge> CutInfo;
|
||||
|
||||
struct PipelineConfiguration {
|
||||
// If model partition happens outside ORT, this flag should be false.
|
||||
// Otherwise, use true to trigger ORT's pipeline partition.
|
||||
|
|
@ -142,7 +164,9 @@ class TrainingSession : public InferenceSession {
|
|||
// Tensors to fetch as specified by the user.
|
||||
// Each pipeline stage should pick up some strings from this field..
|
||||
std::vector<std::string> fetch_names;
|
||||
// [TODO] Add cut information.
|
||||
// cut_list contains the list of CutInfo to make the graph partitions.
|
||||
// cut_list[i] contains the CutInfo to make the partition between stage i and stage i+1
|
||||
std::vector<CutInfo> cut_list;
|
||||
};
|
||||
|
||||
// If pipeline is enabled, this field's has_value() returns true.
|
||||
|
|
@ -329,7 +353,7 @@ class TrainingSession : public InferenceSession {
|
|||
// Insert operators for running pipeline and return event tensor names.
|
||||
// For an intermediate pipeline stage, its original computation is
|
||||
//
|
||||
// Recv --> Forward --> Send -->
|
||||
// Recv --> Forward --> Send -->
|
||||
// Recv --> Backward --> Send
|
||||
//
|
||||
// After this function, the resulted computation is
|
||||
|
|
@ -340,7 +364,7 @@ class TrainingSession : public InferenceSession {
|
|||
// As you can see, some event operators are inserted. For each event operator, its dependent
|
||||
// event tensor name is written to an input references, for example, "forward_waited_event_name".
|
||||
//
|
||||
// This function assumes that
|
||||
// This function assumes that
|
||||
// 1. Only one Recv and only one Send present in forward pass.
|
||||
// 2. Only one Recv and only one Send present in backward pass.
|
||||
// 3. Backward operators' descriptions are all "Backward pass". This assumption is used to
|
||||
|
|
|
|||
|
|
@ -155,6 +155,11 @@ Status ParseArguments(int argc, char* argv[], BertParameters& params, OrtParamet
|
|||
("horizontal_parallel_size", "Horizontal model parallel group size.", cxxopts::value<int>()->default_value("1"))
|
||||
("pipeline_parallel_size", "Number of pipeline stages.", cxxopts::value<int>()->default_value("1"))
|
||||
("pipeline_stage_paths", "Specify the forward ONNX files for pipeline evaluation.", cxxopts::value<std::vector<std::string>>()->default_value(""))
|
||||
("cut_group_info", "Specify the cutting info for graph partition (pipeline only). An example of a cut_group_info of "
|
||||
"size two is: 1393:407-1463/1585/1707,2369:407-2439/2561/2683. Here, the cut info is split by ',', with the first "
|
||||
"cut_info equal to 1393:407-1463/1585/1707, and second cut_info equal to 2369:407-2439/2561/2683. Each CutEdge is "
|
||||
"seperated by ':'. If consumer nodes need to be specified, specify them after producer node with a '-' delimiter and "
|
||||
"separate each consumer node with a '/'. ", cxxopts::value<std::vector<std::string>>()->default_value(""))
|
||||
("enable_grad_norm_clip", "Specify whether to enable gradient clipping for optimizers.",
|
||||
cxxopts::value<bool>()->default_value("true"));
|
||||
options
|
||||
|
|
@ -379,6 +384,58 @@ Status ParseArguments(int argc, char* argv[], BertParameters& params, OrtParamet
|
|||
// Backward pass and optimizer nodes are implicitly generated by ORT.
|
||||
params.pipeline_stage_paths = flags["pipeline_stage_paths"].as<std::vector<std::string>>();
|
||||
|
||||
// If user doesn't provide partitioned model files, a cut list should be provided for ORT to do partition
|
||||
// online. If the pipeline contains n stages, the cut list should be of length (n-1), in order to cut the
|
||||
// graph into n partitions.
|
||||
if (params.pipeline_parallel_size > 1 && params.pipeline_stage_paths.empty()) {
|
||||
auto cut_info_groups = flags["cut_group_info"].as<std::vector<std::string>>();
|
||||
|
||||
ORT_RETURN_IF_NOT(static_cast<int>(cut_info_groups.size() + 1) == params.pipeline_parallel_size,
|
||||
"cut_info length plus one must match pipeline parallel size");
|
||||
|
||||
auto process_with_delimiter = [](std::string& input_str, const std::string& delimiter) {
|
||||
std::vector<std::string> result;
|
||||
size_t pos = 0;
|
||||
std::string token;
|
||||
while ((pos = input_str.find(delimiter)) != std::string::npos) {
|
||||
token = input_str.substr(0, pos);
|
||||
result.emplace_back(token);
|
||||
input_str.erase(0, pos + delimiter.length());
|
||||
}
|
||||
// push the last split of substring into result.
|
||||
result.emplace_back(input_str);
|
||||
return result;
|
||||
};
|
||||
|
||||
auto process_cut_info = [&](std::string& cut_info_string) {
|
||||
TrainingSession::TrainingConfiguration::CutInfo cut_info;
|
||||
const std::string edge_delimiter = ":";
|
||||
const std::string consumer_delimiter = "/";
|
||||
const std::string producer_consumer_delimiter = "-";
|
||||
|
||||
auto cut_edges = process_with_delimiter(cut_info_string, edge_delimiter);
|
||||
for (auto& cut_edge : cut_edges) {
|
||||
auto process_edge = process_with_delimiter(cut_edge, producer_consumer_delimiter);
|
||||
if (process_edge.size() == 1) {
|
||||
TrainingSession::TrainingConfiguration::CutEdge edge{process_edge[0]};
|
||||
cut_info.emplace_back(edge);
|
||||
} else {
|
||||
ORT_ENFORCE(process_edge.size() == 2);
|
||||
auto consumer_list = process_with_delimiter(process_edge[1], consumer_delimiter);
|
||||
|
||||
TrainingSession::TrainingConfiguration::CutEdge edge{process_edge[0], consumer_list};
|
||||
cut_info.emplace_back(edge);
|
||||
}
|
||||
}
|
||||
return cut_info;
|
||||
};
|
||||
|
||||
for (auto& cut_info : cut_info_groups) {
|
||||
TrainingSession::TrainingConfiguration::CutInfo cut = process_cut_info(cut_info);
|
||||
params.pipeline_partition_cut_list.emplace_back(cut);
|
||||
}
|
||||
}
|
||||
|
||||
int64_t seed = flags["seed"].as<int64_t>();
|
||||
if (params.horizontal_parallel_size > 1 && seed <= 0) {
|
||||
seed = 8211; // Megatron needs a random seed.
|
||||
|
|
|
|||
|
|
@ -158,6 +158,7 @@ Status TrainingRunner::Initialize() {
|
|||
// the session already loads a pipeline stage.
|
||||
pipe.do_partition = params_.pipeline_stage_paths.empty() ? true : false;
|
||||
pipe.fetch_names = params_.fetch_names;
|
||||
pipe.cut_list = params_.pipeline_partition_cut_list;
|
||||
// Do not assign value to config.pipeline_config if pipeline is not used.
|
||||
config.pipeline_config = pipe;
|
||||
}
|
||||
|
|
@ -321,7 +322,7 @@ Status TrainingRunner::Initialize() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TrainingRunner::Run(IDataLoader* training_data_loader, IDataLoader* test_data_loader,
|
||||
Status TrainingRunner::Run(IDataLoader* training_data_loader, IDataLoader* test_data_loader,
|
||||
const MapStringToString& mapped_dimensions) {
|
||||
if (params_.mpi_context.world_rank == 0 && !params_.model_actual_running_graph_path.empty()) {
|
||||
session_.Save(params_.model_actual_running_graph_path, TrainingSession::SaveOption::NO_RELOAD);
|
||||
|
|
@ -528,7 +529,7 @@ Status TrainingRunner::PrepareFetchNamesAndFetches(const SessionMode mode,
|
|||
const auto& allowed_fetch_names = pipeline_context_.fetch_names;
|
||||
|
||||
if (mode == ModelUpdateStep) {
|
||||
// Set up tensor to be fetched when doing model update.
|
||||
// Set up tensor to be fetched when doing model update.
|
||||
|
||||
if (params_.pipeline_parallel_size > 1) {
|
||||
// If pipeline is used, we need to filter out fetches which are not in this pipeline stage.
|
||||
|
|
@ -557,7 +558,7 @@ Status TrainingRunner::PrepareFetchNamesAndFetches(const SessionMode mode,
|
|||
}
|
||||
}
|
||||
} else if (mode == GradientAccumulateStep) {
|
||||
// Set up tensor to be fetched when doing gradient accumulation.
|
||||
// Set up tensor to be fetched when doing gradient accumulation.
|
||||
|
||||
if (params_.gradient_accumulation_steps > 1) {
|
||||
auto it = opt_graph_outputs_.find(OptimizerOutputKey::GradientAccumulation);
|
||||
|
|
@ -583,7 +584,7 @@ Status TrainingRunner::PrepareFetchNamesAndFetches(const SessionMode mode,
|
|||
}
|
||||
} else if (mode == EvaluateStep) {
|
||||
// Set up tensor to be fetched when doing model evaluation.
|
||||
// Ideally, this path should not fetch optimizer and gradient accumulation.
|
||||
// Ideally, this path should not fetch optimizer and gradient accumulation.
|
||||
// This path may fetch predicted scores, loss value, and so on.
|
||||
|
||||
if (params_.pipeline_parallel_size > 1) {
|
||||
|
|
@ -658,7 +659,7 @@ Status TrainingRunner::RunWithUpdate(VectorString& feed_names,
|
|||
}
|
||||
}
|
||||
|
||||
// Wait for all workers to finish this round of pipeline parallelism.
|
||||
// Wait all workers to finish this around of pipeline parallism.
|
||||
// The last batch in a pipeline collects gradient and update the model.
|
||||
pipeline_worker_pool_.JoinAll();
|
||||
|
||||
|
|
@ -802,7 +803,8 @@ Status TrainingRunner::TrainingLoop(IDataLoader& training_data_loader, IDataLoad
|
|||
fetch_names,
|
||||
fetches);
|
||||
RunWithoutUpdate(feed_names, fetch_names, feeds,
|
||||
gradient_accumulation_step_count);
|
||||
gradient_accumulation_step_count);
|
||||
|
||||
}
|
||||
|
||||
auto end = std::chrono::high_resolution_clock::now();
|
||||
|
|
@ -880,7 +882,7 @@ Status TrainingRunner::TrainingLoop(IDataLoader& training_data_loader, IDataLoad
|
|||
const size_t peak_workingset_size = perftest::utils::GetPeakWorkingSetSize();
|
||||
ORT_RETURN_IF_ERROR(Env::Default().CreateFolder(params_.perf_output_dir));
|
||||
// saving json file
|
||||
ORT_RETURN_IF_ERROR(SavePerfMetrics(number_of_batches, gradient_accumulation_step_count, weight_update_steps,
|
||||
ORT_RETURN_IF_ERROR(SavePerfMetrics(number_of_batches, gradient_accumulation_step_count, weight_update_steps,
|
||||
total_time, avg_time_per_batch, throughput, stabilized_throughput, mapped_dimensions,
|
||||
average_cpu_usage, peak_workingset_size));
|
||||
}
|
||||
|
|
@ -901,17 +903,17 @@ Status TrainingRunner::TrainingLoop(IDataLoader& training_data_loader, IDataLoad
|
|||
Status TrainingRunner::SavePerfMetrics(const size_t number_of_batches, const size_t gradient_accumulation_steps,
|
||||
const size_t weight_update_steps, const double total_time,
|
||||
const double avg_time_per_batch, const double throughput, const double stabilized_throughput,
|
||||
const MapStringToString& mapped_dimensions,
|
||||
const MapStringToString& mapped_dimensions,
|
||||
const short average_cpu_usage, const size_t peak_workingset_size) {
|
||||
// populate metrics for reporting
|
||||
json perf_metrics;
|
||||
perf_metrics["Model"] = params_.model_type;
|
||||
perf_metrics["Model"] = params_.model_type;
|
||||
|
||||
// loop thru the mapped_dimensions and put it in json sub-structure
|
||||
std::string seq_len;
|
||||
for (auto const& it : mapped_dimensions) {
|
||||
if (it.first == "SeqLen") {
|
||||
seq_len = it.second;
|
||||
seq_len = it.second;
|
||||
}
|
||||
perf_metrics["DerivedProperties"][it.first] = it.second;
|
||||
}
|
||||
|
|
@ -929,7 +931,7 @@ Status TrainingRunner::SavePerfMetrics(const size_t number_of_batches, const siz
|
|||
|
||||
std::string optimizer = params_.training_optimizer_name;
|
||||
std::size_t pos = optimizer.find("Optimizer");
|
||||
if (pos != std::string::npos)
|
||||
if (pos != std::string::npos)
|
||||
optimizer = optimizer.substr(0, pos);
|
||||
perf_metrics["Optimizer"] = optimizer;
|
||||
|
||||
|
|
@ -948,7 +950,7 @@ Status TrainingRunner::SavePerfMetrics(const size_t number_of_batches, const siz
|
|||
|
||||
//
|
||||
// we will get date/time and commitId in post-run pipeline
|
||||
//
|
||||
//
|
||||
|
||||
// populate other basic params for bookkeeping - add more as needed
|
||||
json bookkeeping_params;
|
||||
|
|
@ -962,7 +964,7 @@ Status TrainingRunner::SavePerfMetrics(const size_t number_of_batches, const siz
|
|||
|
||||
perf_metrics["RunConfig"] = bookkeeping_params.dump(); // serialize the params as json string
|
||||
|
||||
std::string json_string = perf_metrics.dump();
|
||||
std::string json_string = perf_metrics.dump();
|
||||
|
||||
// write to a file - the next task in CI will pick up all files with the same prefix
|
||||
const PathString perf_metrics_path =
|
||||
|
|
@ -1072,7 +1074,7 @@ Status TrainingRunner::Evaluate(InferenceSession& session, IDataLoader& data_loa
|
|||
fetch_names,
|
||||
&fetches));
|
||||
|
||||
// Assume that user-specified fetches are available only on the last pipeline stage.
|
||||
// Assume that user-specified fetches are avaliable only on the last pipeline stage.
|
||||
// When there is no pipeline, all pipeline_context_.pipeline_stage_id should be 0 and
|
||||
// params_.pipeline_parallel_size is 1. Thus, the following condition is always true if there
|
||||
// is no pipeline.
|
||||
|
|
|
|||
|
|
@ -156,8 +156,11 @@ class TrainingRunner {
|
|||
// pipeline_parallel_size > 1 means pipeline is enabled.
|
||||
// pipeline_parallel_size == 1 means pipeline is disabled.
|
||||
int pipeline_parallel_size = 1;
|
||||
// pipeline partition information to do online-partition. If the graph is
|
||||
// pre-partitioned, no need to fill this value.
|
||||
std::vector<TrainingSession::TrainingConfiguration::CutInfo> pipeline_partition_cut_list;
|
||||
// model_paths[i] is the name of the pipeline stage for i-th process.
|
||||
// The i-th file is run by the i-th MPI rank.
|
||||
// The i-th file is run by the i-th MPI rank.
|
||||
// If model_paths is not empty, model partition transformation may not be internally invoked.
|
||||
VectorString pipeline_stage_paths;
|
||||
// Enable gradient clipping.
|
||||
|
|
@ -169,7 +172,7 @@ class TrainingRunner {
|
|||
|
||||
common::Status Initialize();
|
||||
|
||||
common::Status Run(IDataLoader* training_data_loader, IDataLoader* test_data_loader,
|
||||
common::Status Run(IDataLoader* training_data_loader, IDataLoader* test_data_loader,
|
||||
const MapStringToString& mapped_dimensions = {});
|
||||
|
||||
common::Status EndTraining(IDataLoader* data_loader);
|
||||
|
|
@ -196,12 +199,12 @@ class TrainingRunner {
|
|||
Status RunWithUpdate(VectorString& feed_names,
|
||||
VectorString& fetch_names,
|
||||
std::vector<MLValue>& feeds,
|
||||
std::vector<MLValue>& fetches);
|
||||
std::vector<MLValue>& fetches);
|
||||
Status RunWithoutUpdate(VectorString& feed_names,
|
||||
VectorString& fetch_names,
|
||||
std::vector<MLValue>& feeds,
|
||||
size_t& gradient_accumulation_step_count);
|
||||
Status TrainingLoop(IDataLoader& training_data_loader, IDataLoader* test_data_loader,
|
||||
size_t& gradient_accumulation_step_count);
|
||||
Status TrainingLoop(IDataLoader& training_data_loader, IDataLoader* test_data_loader,
|
||||
const MapStringToString& mapped_dimensions);
|
||||
Status Evaluate(InferenceSession& session, IDataLoader& data_loader);
|
||||
|
||||
|
|
@ -230,7 +233,7 @@ class TrainingRunner {
|
|||
AllocatorPtr input_allocator_;
|
||||
|
||||
std::unique_ptr<CheckpointRegistry> checkpoint_registry_;
|
||||
|
||||
|
||||
// Pipeline fields are valid only if params_.pipeline_parallel_size > 1.
|
||||
// Information for running pipeline.
|
||||
pipeline::PipelineContext pipeline_context_;
|
||||
|
|
|
|||
|
|
@ -998,7 +998,7 @@ class PipelineBatchPlanner {
|
|||
};
|
||||
|
||||
void RetrieveEventOperators(
|
||||
Graph& graph,
|
||||
Graph& graph,
|
||||
Node** forward_wait_before_recv,
|
||||
Node** forward_wait_after_recv,
|
||||
Node** forward_record_before_send,
|
||||
|
|
@ -1065,7 +1065,7 @@ void RetrieveEventOperators(
|
|||
}
|
||||
|
||||
void RetrieveSendRecvOperators(
|
||||
Graph& graph,
|
||||
Graph& graph,
|
||||
Node** forward_recv,
|
||||
Node** forward_send,
|
||||
Node** backward_recv,
|
||||
|
|
@ -1089,9 +1089,9 @@ void RetrieveSendRecvOperators(
|
|||
if (is_backward(node)) {
|
||||
// backward_send can only be assigned one valid pointer.
|
||||
// If it is assigned more than once, it means we have multiple
|
||||
// Send in backward pass and therefore our assumption doesn't hold.
|
||||
// Send in backward pass and therefore our assumption doesn't hold.
|
||||
// This check ensure that only we only update *backward_send when
|
||||
// its value is NULL and guards our one-Recv assumption.
|
||||
// its value is NULL and guards our one-Recv assumption.
|
||||
ASSERT_TRUE(!(*backward_send));
|
||||
*backward_send = &node;
|
||||
} else {
|
||||
|
|
@ -1115,6 +1115,46 @@ void RetrieveSendRecvOperators(
|
|||
}
|
||||
}
|
||||
|
||||
// Verifies online pipeline partition: cutting the MLP model into three stages must
// yield a loadable training graph for every pipeline rank.
TEST(GradientGraphBuilderTest, PipelineOnlinePartition) {
  const auto model_uri = ORIGINAL_MODEL_PATH;

  TrainingSession::TrainingConfiguration::PipelineConfiguration pipe{};
  pipe.do_partition = true;

  // Evenly cut the MLP model into 3 partitions by splitting at tensors T3 and T6.
  TrainingSession::TrainingConfiguration::CutInfo cut0 = {TrainingSession::TrainingConfiguration::CutEdge("T3")};
  TrainingSession::TrainingConfiguration::CutInfo cut1 = {TrainingSession::TrainingConfiguration::CutEdge("T6")};
  pipe.cut_list.emplace_back(cut0);
  pipe.cut_list.emplace_back(cut1);

  for (int rank = 0; rank < 3; ++rank) {
    // Each rank writes its own partitioned model file.
#ifdef _WIN32
    const auto stage_suffix = std::to_wstring(rank);
#else
    const auto stage_suffix = std::to_string(rank);
#endif
    const PathString output_file = ORT_TSTR("pipeline_partition_") + stage_suffix + ORT_TSTR("_back.onnx");

    auto config = MakeBasicTrainingConfig();
    config.pipeline_config = pipe;
    config.distributed_config.world_rank = rank;
    config.distributed_config.world_size = 3;
    config.distributed_config.local_rank = rank;
    config.distributed_config.local_size = 3;
    config.distributed_config.data_parallel_size = 1;
    config.distributed_config.horizontal_parallel_size = 1;
    config.distributed_config.pipeline_parallel_size = 3;
    config.model_with_training_graph_path = output_file;

    PathString backprop_model_file;
    ASSERT_STATUS_OK(BuildBackPropGraph(model_uri, config, backprop_model_file));

    // Ensure the partitioned model loads.
    std::shared_ptr<Model> model;
    ASSERT_STATUS_OK(Model::Load(backprop_model_file, model, nullptr, DefaultLoggingManager().DefaultLogger()));
  }
}
|
||||
|
||||
// verify pipeline config can load and gradient graph can construct.
|
||||
TEST(GradientGraphBuilderTest, TrainingSession_PipelineTransform_base) {
|
||||
PathString filename_base = ORT_TSTR("testdata/test_training_model_");
|
||||
|
|
@ -1197,7 +1237,7 @@ TEST(GradientGraphBuilderTest, TrainingSession_PipelineTransform_base) {
|
|||
Node* backward_send{nullptr};
|
||||
|
||||
RetrieveSendRecvOperators(
|
||||
graph,
|
||||
graph,
|
||||
&forward_recv,
|
||||
&forward_send,
|
||||
&backward_recv,
|
||||
|
|
|
|||
Loading…
Reference in a new issue