mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-02 03:55:34 +00:00
* Address comments from #3823 and polish code * One line
This commit is contained in:
parent
4ff73d00b0
commit
0d11649bb3
3 changed files with 107 additions and 72 deletions
|
|
@ -30,11 +30,11 @@ NodeArg& CreateInt64NodeArg(Graph& graph, const std::string& name) {
|
|||
return node_arg;
|
||||
}
|
||||
|
||||
void AddInputEvent(Graph& graph, const std::string& op_name,
|
||||
bool is_forward,
|
||||
void AddInputEvent(Graph& graph,
|
||||
const std::string& event_name,
|
||||
std::vector<NodeArg*>& input_args,
|
||||
std::vector<std::string>& new_input_names) {
|
||||
auto& event_id = CreateInt64NodeArg(graph, op_name + (is_forward ? "_fw" : "_bw") + "_event_id");
|
||||
auto& event_id = CreateInt64NodeArg(graph, event_name);
|
||||
new_input_names.push_back(event_id.Name());
|
||||
input_args.push_back(&event_id);
|
||||
}
|
||||
|
|
@ -91,7 +91,7 @@ std::vector<NodeArg*> CreateMirrorNodeArgs(
|
|||
|
||||
// Create a node with input schema [event, input1, input2, ..., inputN] and
|
||||
// output schema [input1, input2, ..., inputN]
|
||||
void CreateBottleneckNode(Graph& graph,
|
||||
Node& CreateBottleneckNode(Graph& graph,
|
||||
const std::string& op_type,
|
||||
const std::string& op_name,
|
||||
const std::string& description,
|
||||
|
|
@ -102,7 +102,8 @@ void CreateBottleneckNode(Graph& graph,
|
|||
if (event) {
|
||||
input_node_args.insert(input_node_args.begin(), event);
|
||||
}
|
||||
graph.AddNode(
|
||||
|
||||
return graph.AddNode(
|
||||
name,
|
||||
op_type,
|
||||
description,
|
||||
|
|
@ -112,14 +113,14 @@ void CreateBottleneckNode(Graph& graph,
|
|||
kMSDomain);
|
||||
}
|
||||
|
||||
Node* AddRecordBackward(Graph& graph,
|
||||
Node* AddBackwardRecord(Graph& graph,
|
||||
Node* backward_send,
|
||||
std::vector<std::string>& new_input_names,
|
||||
std::vector<std::string>& new_output_names,
|
||||
std::string &event_id_tensor_name,
|
||||
std::string &output_tensor_name) {
|
||||
std::vector<NodeArg*> input_args;
|
||||
AddInputEvent(graph, "RecordEvent", false /* is_forward */, input_args, new_input_names);
|
||||
AddInputEvent(graph, "backward_recorded_event_id", input_args, new_input_names);
|
||||
std::vector<NodeArg*> output_args{};
|
||||
|
||||
if (backward_send) {
|
||||
|
|
@ -138,14 +139,9 @@ Node* AddRecordBackward(Graph& graph,
|
|||
output_args.push_back(&new_output);
|
||||
new_output_names.push_back(new_output.Name());
|
||||
|
||||
Node* record_node = &(graph.AddNode(
|
||||
graph.GenerateNodeName("RecordEvent"),
|
||||
"RecordEvent",
|
||||
"Backward pass",
|
||||
input_args,
|
||||
output_args,
|
||||
nullptr,
|
||||
kMSDomain));
|
||||
Node* record_node = &CreateBottleneckNode(
|
||||
graph, "RecordEvent", "backward_record", "Backward pass", nullptr,
|
||||
input_args, output_args);
|
||||
|
||||
// First input argument is the recorded event ID tensor.
|
||||
event_id_tensor_name = input_args.front()->Name();
|
||||
|
|
@ -156,7 +152,7 @@ Node* AddRecordBackward(Graph& graph,
|
|||
return record_node;
|
||||
}
|
||||
|
||||
Node* AddWaitForward(Graph& graph,
|
||||
Node* AddForwardWait(Graph& graph,
|
||||
Node* /* forward_recv */,
|
||||
std::vector<std::string>& new_input_names,
|
||||
std::string& forward_waited_event_name,
|
||||
|
|
@ -179,7 +175,7 @@ Node* AddWaitForward(Graph& graph,
|
|||
|
||||
std::vector<NodeArg*> input_args;
|
||||
std::vector<NodeArg*> output_args;
|
||||
AddInputEvent(graph, "WaitEvent", true /* is_forward */, input_args, new_input_names);
|
||||
AddInputEvent(graph, "forward_waited_event_id", input_args, new_input_names);
|
||||
const std::vector<const NodeArg*>& graph_inputs = graph.GetInputsIncludingInitializers();
|
||||
|
||||
if (graph_inputs.size() == 0){
|
||||
|
|
@ -199,20 +195,20 @@ Node* AddWaitForward(Graph& graph,
|
|||
}
|
||||
}
|
||||
}
|
||||
Node* wait_node = &(graph.AddNode(
|
||||
graph.GenerateNodeName("WaitEvent"),
|
||||
"WaitEvent",
|
||||
"",
|
||||
input_args,
|
||||
output_args,
|
||||
nullptr,
|
||||
kMSDomain));
|
||||
|
||||
Node* wait_node = &CreateBottleneckNode(
|
||||
graph, "WaitEvent", "backward_record", "", nullptr,
|
||||
input_args, output_args);
|
||||
|
||||
forward_waited_event_name = input_args.front()->Name();
|
||||
output_tensor_name = output_args.front()->Name();
|
||||
|
||||
return wait_node;
|
||||
}
|
||||
|
||||
Status AddOrSkipRecordForwardWaitBackward(Graph& graph,
|
||||
// If the input "graph" is the last pipeline stage, this function don't add any
|
||||
// event operators.
|
||||
Status AddOrSkipForwardRecordBackwardWait(Graph& graph,
|
||||
Node* forward_send,
|
||||
Node* backward_recv,
|
||||
std::vector<std::string>& new_input_names,
|
||||
|
|
@ -227,11 +223,13 @@ Status AddOrSkipRecordForwardWaitBackward(Graph& graph,
|
|||
}
|
||||
|
||||
if (!forward_send && !backward_recv){
|
||||
// Last partition doesn't have send forwrad and recv backward. No insert needed.
|
||||
// Last partition doesn't have send forwrad and recv backward. No insert
|
||||
// needed.
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// if we have a send forward op followed by a recv backward op, insert WaitEvent and RecordEvent in between.
|
||||
// if we have a send forward op followed by a recv backward op, insert
|
||||
// WaitEvent and RecordEvent in between.
|
||||
Node* record_node = nullptr;
|
||||
Node* wait_node = nullptr;
|
||||
|
||||
|
|
@ -239,7 +237,7 @@ Status AddOrSkipRecordForwardWaitBackward(Graph& graph,
|
|||
{
|
||||
std::vector<NodeArg*> input_args;
|
||||
std::vector<NodeArg*> output_args;
|
||||
AddInputEvent(graph, "RecordEvent", true /* is_forward */, input_args, new_input_names);
|
||||
AddInputEvent(graph, "forward_recorded_event_id", input_args, new_input_names);
|
||||
|
||||
// Add send forward op's output as record op's input and output
|
||||
for (auto& output : forward_send->MutableOutputDefs()) {
|
||||
|
|
@ -248,23 +246,19 @@ Status AddOrSkipRecordForwardWaitBackward(Graph& graph,
|
|||
input_args.push_back(output);
|
||||
}
|
||||
|
||||
auto& new_node = graph.AddNode(graph.GenerateNodeName("RecordEvent"),
|
||||
"RecordEvent",
|
||||
"",
|
||||
input_args,
|
||||
output_args, /* output */
|
||||
{}, /* attribute */
|
||||
kMSDomain);
|
||||
record_node = &new_node;
|
||||
record_node = &CreateBottleneckNode(
|
||||
graph, "RecordEvent", "forward_record", "", nullptr,
|
||||
input_args, output_args);
|
||||
|
||||
forward_recorded_event_name = record_node->InputDefs()[0]->Name();
|
||||
forward_output_name = record_node->OutputDefs()[0]->Name();
|
||||
}
|
||||
|
||||
// Insert WaitEvent
|
||||
{
|
||||
std::vector<NodeArg*> input_args;
|
||||
std::vector<NodeArg*> output_args;
|
||||
AddInputEvent(graph, "WaitEvent", false /* is_forward */, input_args, new_input_names);
|
||||
AddInputEvent(graph, "backward_waited_event_id", input_args, new_input_names);
|
||||
|
||||
input_args.insert(std::end(input_args),
|
||||
std::begin(record_node->MutableOutputDefs()),
|
||||
|
|
@ -275,14 +269,9 @@ Status AddOrSkipRecordForwardWaitBackward(Graph& graph,
|
|||
output_args.push_back(&new_output);
|
||||
input = &new_output;
|
||||
|
||||
auto& new_node = graph.AddNode(graph.GenerateNodeName("WaitEvent"),
|
||||
"WaitEvent",
|
||||
"Backward pass",
|
||||
input_args,
|
||||
output_args, /* output */
|
||||
{}, /* attribute */
|
||||
kMSDomain);
|
||||
wait_node = &new_node;
|
||||
wait_node = &CreateBottleneckNode(
|
||||
graph, "WaitEvent", "backward_wait", "Backward pass", nullptr,
|
||||
input_args, output_args);
|
||||
|
||||
backward_waited_event_name = wait_node->InputDefs()[0]->Name();
|
||||
backward_output_name = wait_node->OutputDefs()[0]->Name();
|
||||
|
|
@ -297,12 +286,14 @@ void ReplaceNodeArgs(std::vector<Node*>& nodes,
|
|||
std::vector<NodeArg*>& new_node_args) {
|
||||
ORT_ENFORCE(node_args.size() == new_node_args.size());
|
||||
for (size_t i = 0; i < node_args.size(); ++i) {
|
||||
// At this iteration, we replace node_args[i] with
|
||||
// Iteration for node_args[i] and new_node_args[i].
|
||||
|
||||
ORT_ENFORCE(node_args[i]->Name() != new_node_args[i]->Name());
|
||||
ORT_ENFORCE(node_args[i]->Type() == new_node_args[i]->Type());
|
||||
|
||||
for (auto& node: nodes) {
|
||||
for (auto& node_arg: node->MutableInputDefs()) {
|
||||
// Only replace when node's input name matches node_args[i].
|
||||
if (node_arg->Name().compare(node_args[i]->Name()) != 0) {
|
||||
continue;
|
||||
}
|
||||
|
|
@ -346,10 +337,9 @@ std::string AddEventBeforeNode(
|
|||
auto event_node_arg = &CreateInt64NodeArg(graph, event_id_name);
|
||||
|
||||
// Create node which produces new_node_args from event ID and node_args.
|
||||
auto name = graph.GenerateNodeName(event_op_name);
|
||||
CreateBottleneckNode(graph,
|
||||
event_op_type,
|
||||
name,
|
||||
event_op_name,
|
||||
"",
|
||||
event_node_arg,
|
||||
node_args,
|
||||
|
|
@ -389,10 +379,9 @@ std::string AddEventAfterNode(
|
|||
auto event_node_arg = &CreateInt64NodeArg(graph, event_id_name);
|
||||
|
||||
// Create node which produces new_node_args from event ID and node_args.
|
||||
auto name = graph.GenerateNodeName(event_op_name);
|
||||
CreateBottleneckNode(graph,
|
||||
event_op_type,
|
||||
name,
|
||||
event_op_name,
|
||||
"",
|
||||
event_node_arg,
|
||||
node_args,
|
||||
|
|
@ -469,7 +458,52 @@ Status AddBackwardRecordBeforeSend(
|
|||
}
|
||||
}
|
||||
|
||||
// Insert WaitEvent and RecordEvent to the partition.
|
||||
// This function inserts WaitEvent's and RecordEvent's to the input graph for
|
||||
// controlling synchronization between (batch, pipeline stage)-pairs.
|
||||
//
|
||||
// The input graph is a pipeline's stage, which contains some Send's and Recv's.
|
||||
//
|
||||
// For diferent pipeline stages, they have different communication patterns as
|
||||
// shown below.
|
||||
//
|
||||
// 1. First stage:
|
||||
// FW -----------> Send ----------->
|
||||
// ------> Recv ---------> BW
|
||||
// 2. Middle stage:
|
||||
// Recv ---------> FW -----------> Send ----------->
|
||||
// ------> Recv ---------> BW -----------> Send
|
||||
// 3. Last stage:
|
||||
// Recv ---------> FW ----------------------------->
|
||||
// ----------------------> BW -----------> Send
|
||||
//
|
||||
// This function inserts some event operators and those patterns become
|
||||
//
|
||||
// 1. First stage:
|
||||
// Wait ---------> Wait -> FW -> Record -> Send -> Record ->
|
||||
// Wait -> Recv -> Wait -> BW -> Record ---------> Record
|
||||
// 2. Middle stage:
|
||||
// Wait -> Recv -> Wait -> FW -> Record -> Send -> Record ->
|
||||
// Wait -> Recv -> Wait -> BW -> Record -> Send -> Record
|
||||
// 3. Last stage:
|
||||
// Wait -> Recv -> Wait -> FW ----------------------------->
|
||||
// ----------------------> BW -> Record -> Send -> Record
|
||||
//
|
||||
// To explain the meaning of those operators, we take the middle stage's pattern
|
||||
// as an example:
|
||||
//
|
||||
// Wait-0 -> Recv -> Wait-1 -> FW -> Record-0 -> Send -> Record-1 ->
|
||||
// Wait-2 -> Recv -> Wait-3 -> BW -> Record-2 -> Send -> Record-3
|
||||
//
|
||||
// Their meanings are listed below.
|
||||
//
|
||||
// Wait-0: Wait until we can start reciving forward data.
|
||||
// Wait-1: Wait until we can start forward pass.
|
||||
// Record-0: Tell others that forward pass is done.
|
||||
// Record-1: Tell others that forward result has been passed to another stage.
|
||||
// Wait-2: Wait until we can start reciving backward data.
|
||||
// Wait-3: Wait until we can start backward bass.
|
||||
// Record-2: Tell others that backward pass is done.
|
||||
// Record-3: Tell others that backward result has been passed to another stage.
|
||||
Status TransformGraphForPipeline(
|
||||
Graph& graph,
|
||||
std::string& forward_waited_event_name,
|
||||
|
|
@ -508,26 +542,26 @@ Status TransformGraphForPipeline(
|
|||
}
|
||||
|
||||
// Names to added into this graph's input list.
|
||||
// Their value may be provides as "feeds" when calling session.Run(...).
|
||||
// Their values may be provides as "feeds" when calling session.Run(...).
|
||||
std::vector<std::string> new_input_names;
|
||||
// Names to added into this graph's output list.
|
||||
// Their value may be provides as "feeds" when calling session.Run(...).
|
||||
// Their values may be returned as "fetches" when calling session.Run(...).
|
||||
std::vector<std::string> new_output_names;
|
||||
|
||||
backward_record = AddRecordBackward(
|
||||
backward_record = AddBackwardRecord(
|
||||
graph,
|
||||
backward_send,
|
||||
new_input_names,
|
||||
new_output_names,
|
||||
backward_recorded_event_name,
|
||||
backward_record_output_name);
|
||||
forward_wait = AddWaitForward(
|
||||
forward_wait = AddForwardWait(
|
||||
graph,
|
||||
forward_recv,
|
||||
new_input_names,
|
||||
forward_waited_event_name,
|
||||
forward_wait_output_name);
|
||||
ORT_RETURN_IF_ERROR(AddOrSkipRecordForwardWaitBackward(
|
||||
ORT_RETURN_IF_ERROR(AddOrSkipForwardRecordBackwardWait(
|
||||
graph,
|
||||
forward_send,
|
||||
backward_recv,
|
||||
|
|
@ -552,19 +586,6 @@ Status TransformGraphForPipeline(
|
|||
// 3. Last stage:
|
||||
// Wait -> Recv ---------> FW ----------------------------->
|
||||
// ----------------------> BW -----------> Send -> Record
|
||||
//
|
||||
// After applying all transformations below, we will have
|
||||
// the following patterns.
|
||||
//
|
||||
// 1. First stage:
|
||||
// Wait ---------> Wait -> FW -> Record -> Send -> Record ->
|
||||
// Wait -> Recv -> Wait -> BW -> Record ---------> Record
|
||||
// 2. Middle stage:
|
||||
// Wait -> Recv -> Wait -> FW -> Record -> Send -> Record ->
|
||||
// Wait -> Recv -> Wait -> BW -> Record -> Send -> Record
|
||||
// 3. Last stage:
|
||||
// Wait -> Recv -> Wait -> FW ----------------------------->
|
||||
// ----------------------> BW -> Record -> Send -> Record
|
||||
const bool is_first_stage = !forward_recv && forward_send && backward_recv && !backward_send;
|
||||
const bool is_middle_stage = forward_recv && forward_send && backward_recv && backward_send;
|
||||
const bool is_last_stage = forward_recv && !forward_send && !backward_recv && backward_send;
|
||||
|
|
@ -572,7 +593,7 @@ Status TransformGraphForPipeline(
|
|||
// One and only one of is_first_stage, is_middle_stage, and is_last_stage can be true.
|
||||
const unsigned int stage_flag_sum = is_first_stage + is_middle_stage + is_last_stage;
|
||||
ORT_RETURN_IF_NOT(stage_flag_sum == 1u,
|
||||
"The processed graph should be classified into an stage, "
|
||||
"The processed graph should be classified into a stage, "
|
||||
"but we see more than one true's in the following statements. ",
|
||||
"Is first stage? ", is_first_stage, ". ",
|
||||
"Is middle stage? ", is_middle_stage, ". ",
|
||||
|
|
|
|||
|
|
@ -87,7 +87,7 @@ class PipelineSchedule {
|
|||
// It equals to table_.size().
|
||||
int num_stages_;
|
||||
// Total number of batches scheduled in this pipeline.
|
||||
// It equals to table_[i].size(), for i = 0, ..., num_stages_.
|
||||
// It equals to table_[i].size(), for i = 0, ..., num_stages_ - 1.
|
||||
int num_batches_;
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -1087,14 +1087,28 @@ void RetrieveSendRecvOperators(
|
|||
for (auto& node : graph.Nodes()) {
|
||||
if (node.OpType() == "Send") {
|
||||
if (is_backward(node)) {
|
||||
// backward_send can only be assigned one valid pointer.
|
||||
// If it is assigned more than once, it means we have multiple
|
||||
// Send in backward pass and therefore our assumption doesn't hold.
|
||||
// This check ensure that only we only update *backward_send when
|
||||
// its value is NULL and guards our one-Recv assumption.
|
||||
ASSERT_TRUE(!(*backward_send));
|
||||
*backward_send = &node;
|
||||
} else {
|
||||
// Guard the uniqueness of Send in the forward pass by throwing
|
||||
// when *forward_send already carries a valid pointer.
|
||||
ASSERT_TRUE(!(*forward_send));
|
||||
*forward_send = &node;
|
||||
}
|
||||
} else if (node.OpType() == "Recv") {
|
||||
if (is_backward(node)) {
|
||||
// Guard the uniqueness of Recv in the backward pass by throwing
|
||||
// when *backward_recv already carries a valid pointer.
|
||||
ASSERT_TRUE(!(*backward_recv));
|
||||
*backward_recv = &node;
|
||||
} else {
|
||||
// Guard the uniqueness of Recv in the forwaard pass by throwing
|
||||
// when *forward_recv already carries a valid pointer.
|
||||
*forward_recv = &node;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue