diff --git a/onnxruntime/test/testdata/test_training_model_0.onnx b/onnxruntime/test/testdata/test_training_model_0.onnx new file mode 100644 index 0000000000..2991cb758c Binary files /dev/null and b/onnxruntime/test/testdata/test_training_model_0.onnx differ diff --git a/onnxruntime/test/testdata/test_training_model_1.onnx b/onnxruntime/test/testdata/test_training_model_1.onnx new file mode 100644 index 0000000000..d7041e03bd Binary files /dev/null and b/onnxruntime/test/testdata/test_training_model_1.onnx differ diff --git a/onnxruntime/test/testdata/test_training_model_2.onnx b/onnxruntime/test/testdata/test_training_model_2.onnx new file mode 100644 index 0000000000..d0958df31a Binary files /dev/null and b/onnxruntime/test/testdata/test_training_model_2.onnx differ diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index 5c58e70e8b..609ff30304 100644 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -1045,7 +1045,7 @@ IMPLEMENT_GRADIENT_BUILDER(GetSendGradient) { } return std::vector{ - NodeDef("Recv", + NodeDef(OpDef{"Recv", kMSDomain, 1}, {O(0), I(1)}, // {Signal, Remote} out_args, SrcNodeAttributes())}; @@ -1057,14 +1057,14 @@ IMPLEMENT_GRADIENT_BUILDER(GetRecvGradient) { std::vector in_args; in_args.push_back(O(0)); // Signal - in_args.push_back(I(0)); // Remote + in_args.push_back(I(1)); // Remote for (int i = 1; i < GetSrcNodeOutputSize(); ++i) { in_args.push_back(GO(i)); // Data } return std::vector{ - NodeDef("Send", + NodeDef(OpDef{"Send", kMSDomain, 1}, in_args, {GI(0)}, // Signal SrcNodeAttributes())}; diff --git a/orttraining/orttraining/core/graph/gradient_schema_defs.cc b/orttraining/orttraining/core/graph/gradient_schema_defs.cc index adff2ed325..9249dd8410 100644 --- a/orttraining/orttraining/core/graph/gradient_schema_defs.cc +++ b/orttraining/orttraining/core/graph/gradient_schema_defs.cc @@ -1670,7 +1670,7 @@ Return true if all elements are true and false otherwise. "Allow inputs and outputs to be any kind of tensor.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { if (ctx.getNumInputs() < ctx.getNumOutputs() + 1) - fail_shape_inference("WaitEvent must have at least (num_outputs + 1) inputs."); + fail_shape_inference("RecordEvent must have at least (num_outputs + 1) inputs."); // note: if num_input > num_output + 1, // the additional inputs (idx >= num_ouput + 1) are regarded as dependencies diff --git a/orttraining/orttraining/core/graph/pipeline_transformer.cc b/orttraining/orttraining/core/graph/pipeline_transformer.cc new file mode 100644 index 0000000000..39700c6e9b --- /dev/null +++ b/orttraining/orttraining/core/graph/pipeline_transformer.cc @@ -0,0 +1,279 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "orttraining/core/graph/pipeline_transformer.h" + +using namespace onnxruntime::common; + +namespace onnxruntime { +namespace training { + +void GetPipelineSendOutput(const Graph& graph, std::string& loss_name) { + for (auto& node : graph.Nodes()) { + if (!node.OpType().compare("Send")) { + // send op should always have an output, which is the OutputSignal. + loss_name = node.OutputDefs()[0]->Name(); + return; + } + } +} + +bool IsBackward(Node& node) { + return (node.Description() == "Backward pass"); +} + +void AddInputEvent(Graph& graph, const std::string& op_name, + bool is_forward, + std::vector& input_args, + std::vector& new_input_names) { + ONNX_NAMESPACE::TypeProto event_type_proto; + event_type_proto.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + + auto event_id_name = graph.GenerateNodeArgName(op_name + (is_forward ? "_fw" : "_bw") + "_event_id"); + auto& event_id = graph.GetOrCreateNodeArg(event_id_name, &event_type_proto); + new_input_names.push_back(event_id_name); + + input_args.push_back(&event_id); +} + +// gradient graph can contain some dangling leaf nodes. Add them all to WaitEvent +// backward node's input. +void FindLeafNodes(Graph& graph, std::vector& input_args) { + for (auto& node : graph.Nodes()) { + if (!IsBackward(node)) { + // only check backward node + continue; + } + bool find_consumer_nodes = false; + std::vector& outputs = node.MutableOutputDefs(); + for (auto& output : outputs) { + std::vector consumer_nodes = graph.GetConsumerNodes(output->Name()); + if (consumer_nodes.size() > 0) { + find_consumer_nodes = true; + break; + } + } + if (!find_consumer_nodes && outputs.size() > 0) { + input_args.push_back(outputs[0]); + } + } +}; + +NodeArg& CreateNodeArg(Graph& graph, const NodeArg* base_arg) { + const auto& new_name = graph.GenerateNodeArgName(base_arg->Name()); + ONNX_NAMESPACE::TypeProto type_proto(*(base_arg->TypeAsProto())); + if (graph.GetNodeArg(new_name) != nullptr) { + ORT_THROW("Node with name ", new_name, " already exists."); + } + return graph.GetOrCreateNodeArg(new_name, &type_proto); +} + +Status AddRecordBackward(Graph& graph, + Node* send_bw, + std::vector& new_input_names, + std::vector& new_output_names) { + std::vector input_args; + AddInputEvent(graph, "RecordEvent", false /* is_forward */, input_args, new_input_names); + std::vector output_args{}; + + if (send_bw) { + // if we have send op in backward pass (at the end of the graph), we make sure the RecordEvent happens + // after that send by adding Send's outputs to RecordEvent's input list. + input_args.insert(std::end(input_args), + std::begin(send_bw->MutableOutputDefs()), + std::end(send_bw->MutableOutputDefs())); + } + FindLeafNodes(graph, input_args); + + // Optimizer will be added after applying pipeline transformer. To support partial graph evaluation, + // the added Record backward op will have its first passthrough input as output. + ORT_RETURN_IF_NOT(input_args.size() >= 2, "RecordEvent backward op at least have two inputs.") + auto& new_output = CreateNodeArg(graph, input_args[1]); // the first input is signal, not passing through + output_args.push_back(&new_output); + new_output_names.push_back(new_output.Name()); + + graph.AddNode(graph.GenerateNodeName("RecordEvent"), + "RecordEvent", + "Backward pass", + input_args, + output_args, + nullptr, + kMSDomain); + return Status::OK(); +} + +Status AddWaitForward(Graph& graph, Node* /* recv_fw */, std::vector& new_input_names) { + // Append old_input to input_args and return its pass-through value. Note that + // input_args and output_args are Wait's inputs and outputs, respectively. + auto update_wait_input_output = [&](NodeArg* old_input, + std::vector& input_args, + std::vector& output_args) -> NodeArg& { + input_args.push_back(old_input); + + const auto& new_name = graph.GenerateNodeArgName(old_input->Name()); + ONNX_NAMESPACE::TypeProto type_proto(*(old_input->TypeAsProto())); + + auto& wait_output = graph.GetOrCreateNodeArg(new_name, &type_proto); + output_args.push_back(&wait_output); + + return wait_output; + }; + + std::vector input_args; + std::vector output_args; + AddInputEvent(graph, "WaitEvent", true /* is_forward */, input_args, new_input_names); + const std::vector& graph_inputs = graph.GetInputsIncludingInitializers(); + + if (graph_inputs.size() == 0){ + ORT_THROW("Graph ", graph.Name(), " doesn't have any inputs."); + } + + for (auto& input_arg : graph_inputs) { + NodeArg* mutable_input = graph.GetNodeArg(input_arg->Name()); + auto& wait_output = update_wait_input_output(mutable_input, input_args, output_args); + std::vector nodes = graph.GetMutableConsumerNodes(input_arg->Name()); + for (auto& consumer_node : nodes) { + for (auto& i : consumer_node->MutableInputDefs()) { + if (i->Name() == input_arg->Name()) { + // if the node is fed by input, re-direct it to be fed by WaitEvent's output. + i = &wait_output; + } + } + } + } + graph.AddNode(graph.GenerateNodeName("WaitEvent"), + "WaitEvent", + "", + input_args, + output_args, + nullptr, + kMSDomain); + + return Status::OK(); +} + +Status AddOrSkipRecordForwardWaitBackward(Graph& graph, Node* send_fw, Node* recv_bw, std::vector& new_input_names) { + if (!send_fw != !recv_bw){ + ORT_THROW("Graph requires either having both send forward node " + "and recv backword node, or none of them. Currently the graph " + "has send forward: ", send_fw, " and recv backward: ", recv_bw); + } + + if (!send_fw && !recv_bw){ + // Last partition doesn't have send forwrad and recv backward. No insert needed. + return Status::OK(); + } + + // if we have a send forward op followed by a recv backward op, insert WaitEvent and RecordEvent in between. + Node* record_node = nullptr; + Node* wait_node = nullptr; + + // Insert RecordEvent + { + std::vector input_args; + std::vector output_args; + AddInputEvent(graph, "RecordEvent", true /* is_forward */, input_args, new_input_names); + + // Add send forward op's output as record op's input and output + for (auto& output : send_fw->MutableOutputDefs()) { + auto& new_output = CreateNodeArg(graph, output); + output_args.push_back(&new_output); + input_args.push_back(output); + } + + auto& new_node = graph.AddNode(graph.GenerateNodeName("RecordEvent"), + "RecordEvent", + "", + input_args, + output_args, /* output */ + {}, /* attribute */ + kMSDomain); + record_node = &new_node; + } + // Insert WaitEvent + { + std::vector input_args; + std::vector output_args; + AddInputEvent(graph, "WaitEvent", false /* is_forward */, input_args, new_input_names); + + input_args.insert(std::end(input_args), + std::begin(record_node->MutableOutputDefs()), + std::end(record_node->MutableOutputDefs())); + + auto& input = recv_bw->MutableInputDefs()[0]; + auto& new_output = CreateNodeArg(graph, input); + output_args.push_back(&new_output); + input = &new_output; + + auto& new_node = graph.AddNode(graph.GenerateNodeName("WaitEvent"), + "WaitEvent", + "Backward pass", + input_args, + output_args, /* output */ + {}, /* attribute */ + kMSDomain); + wait_node = &new_node; + ORT_UNUSED_PARAMETER(wait_node); + } + + return Status::OK(); +} + +Status TransformGraphForPipeline(Graph& graph) { + // insert WaitEvent and RecordEvent to the partition + Node* send_fw{nullptr}; + Node* send_bw{nullptr}; + Node* recv_fw{nullptr}; + Node* recv_bw{nullptr}; + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Send") { + if (IsBackward(node)) { + send_bw = &node; + } else { + send_fw = &node; + } + } else if (node.OpType() == "Recv") { + if (IsBackward(node)) { + recv_bw = &node; + } else { + recv_fw = &node; + } + } + } + + std::vector new_input_names; + std::vector new_output_names; + + ORT_RETURN_IF_ERROR(AddRecordBackward(graph, send_bw, new_input_names, new_output_names)); + ORT_RETURN_IF_ERROR(AddWaitForward(graph, recv_fw, new_input_names)); + ORT_RETURN_IF_ERROR(AddOrSkipRecordForwardWaitBackward(graph, send_fw, recv_bw, new_input_names)); + + auto fill_node_args = [&](const Graph& graph, + const std::vector& existed_node_args, + std::vector& new_node_arg_names, + std::vector& merged_node_args) { + merged_node_args.insert(merged_node_args.end(), existed_node_args.begin(), existed_node_args.end()); + for (auto& name : new_node_arg_names) { + merged_node_args.push_back(graph.GetNodeArg(name)); + } + }; + + const std::vector& graph_inputs = graph.GetInputsIncludingInitializers(); + std::vector inputs_args_sets; + inputs_args_sets.reserve(graph_inputs.size() + new_input_names.size()); + fill_node_args(graph, graph_inputs, new_input_names, inputs_args_sets); + + const std::vector& graph_outputs = graph.GetOutputs(); + std::vector outputs_args_sets; + outputs_args_sets.reserve(graph_outputs.size() + new_output_names.size()); + fill_node_args(graph, graph_outputs, new_output_names, outputs_args_sets); + + graph.SetInputs(inputs_args_sets); + graph.SetOutputs(outputs_args_sets); + graph.SetGraphResolveNeeded(); + graph.SetGraphProtoSyncNeeded(); + return graph.Resolve(); +} + +} // namespace training +} // namespace onnxruntime diff --git a/orttraining/orttraining/core/graph/pipeline_transformer.h b/orttraining/orttraining/core/graph/pipeline_transformer.h new file mode 100644 index 0000000000..11519024e1 --- /dev/null +++ b/orttraining/orttraining/core/graph/pipeline_transformer.h @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/graph/graph.h" + +namespace onnxruntime { +namespace training { + +void GetPipelineSendOutput(const Graph& graph, std::string& loss_name); +common::Status TransformGraphForPipeline(Graph& graph); + +} // namespace training +} // namespace onnxruntime diff --git a/orttraining/orttraining/core/session/training_session.cc b/orttraining/orttraining/core/session/training_session.cc index 57e00e5780..117ffcf262 100644 --- a/orttraining/orttraining/core/session/training_session.cc +++ b/orttraining/orttraining/core/session/training_session.cc @@ -15,6 +15,7 @@ #include "core/optimizer/rule_based_graph_transformer.h" #include "orttraining/core/graph/mixed_precision_transformer.h" #include "orttraining/core/graph/tensorboard_transformer.h" +#include "orttraining/core/graph/pipeline_transformer.h" #include "orttraining/core/graph/gradient_builder_base.h" //Gist Encoding @@ -128,15 +129,23 @@ Status TrainingSession::ConfigureForTraining( is_mixed_precision_enabled_ = config.mixed_precision_config.has_value(); std::string loss_name{}; - const optional loss_function_info = - config.loss_function_config.has_value() - ? config.loss_function_config.value().loss_function_info - : optional{}; optional loss_scale_input_name = - is_mixed_precision_enabled_ ? optional{""} : optional{}; - ORT_RETURN_IF_ERROR(ConfigureLossFunction( - config.loss_name, loss_function_info, - loss_scale_input_name.has_value() ? &loss_scale_input_name.value() : nullptr, loss_name)); + is_mixed_precision_enabled_ ? optional{""} : optional{}; + if (config.use_pipeline) { + // if use pipeline, first check if model contains send op. If it does, set the + // send node's output as the start tensor to build gradient graph + GetPipelineSendOutput(model_->MainGraph(), loss_name); + } + if (loss_name.empty()) { + const optional loss_function_info = + config.loss_function_config.has_value() + ? config.loss_function_config.value().loss_function_info + : optional{}; + ORT_RETURN_IF_ERROR(ConfigureLossFunction( + config.loss_name, loss_function_info, + loss_scale_input_name.has_value() ? &loss_scale_input_name.value() : nullptr, loss_name)); + } + ORT_ENFORCE( !loss_scale_input_name.has_value() || !loss_scale_input_name.value().empty(), "loss_scale_input_name should not be set to an empty string."); @@ -170,7 +179,6 @@ Status TrainingSession::ConfigureForTraining( << weight_names_stream.str(); } - // add gradient graph ORT_RETURN_IF_ERROR(BuildGradientGraph( weight_names_to_train, loss_name, config.set_gradients_as_graph_outputs)); @@ -182,12 +190,30 @@ Status TrainingSession::ConfigureForTraining( weight_names_to_train, mixed_precision_config.use_fp16_initializers, fp32_weight_name_to_fp16_node_arg)); } + if (config.use_pipeline) { + ORT_RETURN_IF_ERROR(InsertPipelineOps()); + } + + // All non-float tensors are not trainable. Remove those weights. + // TODO: this is a temp workaround for removing rank tensor before adding optimizer. + // Re-visit after we port logic for model splitting and hence know the rank tensor name. + for (auto it = weights_to_train_.begin(); it != weights_to_train_.end();) { + const auto* node_arg = model_->MainGraph().GetNodeArg(*it); + ORT_RETURN_IF_NOT(node_arg, "Failed to get NodeArg with name ", *it); + if (node_arg->TypeAsProto()->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + it = weights_to_train_.erase(it); + } + else{ + ++it; + } + } + // add optimizer or gradient accumulation if (config.optimizer_config.has_value()) { OptimizerGraphConfig opt_graph_config{}; std::unordered_map opt_node_configs{}; ORT_RETURN_IF_ERROR(SetupOptimizerParams( - weight_names_to_train, fp32_weight_name_to_fp16_node_arg, + weights_to_train_, fp32_weight_name_to_fp16_node_arg, loss_scale_input_name, config, opt_graph_config, opt_node_configs)); TrainingConfigurationResult::OptimizerConfigurationResult optimizer_config_result{}; @@ -198,7 +224,7 @@ Status TrainingSession::ConfigureForTraining( config_result.opt_config_result = optimizer_config_result; } else { if (config.gradient_accumulation_steps > 1) { - ORT_RETURN_IF_ERROR(BuildAccumulationNode(weight_names_to_train)); + ORT_RETURN_IF_ERROR(BuildAccumulationNode(weights_to_train_)); } } @@ -439,6 +465,11 @@ Status TrainingSession::AddTensorboard(const std::string& summary_name, return DoPostLoadProcessing(*model_); } +Status TrainingSession::InsertPipelineOps() { + ORT_RETURN_IF_ERROR(TransformGraphForPipeline(model_->MainGraph())); + return DoPostLoadProcessing(*model_); +} + Status TrainingSession::ConfigureLossFunction( const optional& external_loss_name, const optional& loss_function_info, diff --git a/orttraining/orttraining/core/session/training_session.h b/orttraining/orttraining/core/session/training_session.h index bf06eb43d4..91725cadc1 100644 --- a/orttraining/orttraining/core/session/training_session.h +++ b/orttraining/orttraining/core/session/training_session.h @@ -132,6 +132,9 @@ class TrainingSession : public InferenceSession { // The optimizer configuration. // If not provided, no optimizer is added. optional optimizer_config{}; + + // Whether to use pipeline in training. + bool use_pipeline{false}; }; /** @@ -262,6 +265,7 @@ class TrainingSession : public InferenceSession { const std::vector& norm_nodes, const bool dump_convergence_metrics); + common::Status InsertPipelineOps(); common::Status ApplyTransformationsToMainGraph(); /** configure initial transformers for training */ diff --git a/orttraining/orttraining/test/graph/gradient_graph_builder_test.cc b/orttraining/orttraining/test/graph/gradient_graph_builder_test.cc index 1c1a63bbe5..7d6b015219 100644 --- a/orttraining/orttraining/test/graph/gradient_graph_builder_test.cc +++ b/orttraining/orttraining/test/graph/gradient_graph_builder_test.cc @@ -997,6 +997,106 @@ class PipelineBatchPlanner { } }; +// verify pipeline config can load and gradient graph can construct. +TEST(GradientGraphBuilderTest, TrainingSession_PipelineTransform_base) { + PathString filename_base = ORT_TSTR("testdata/test_training_model_"); + + auto load_gradient_graph = [](int stageIdx, PathString& input_file, PathString& output_file) { + auto config = MakeBasicTrainingConfig(); + + config.use_pipeline = true; + + PathString backprop_model_file; + ASSERT_STATUS_OK(BuildBackPropGraph(input_file, config, backprop_model_file)); + + std::shared_ptr model; + ASSERT_TRUE(Model::Load(backprop_model_file, model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); + + Graph& graph = model->MainGraph(); + auto is_backward = [](Node& node) { + return (node.Description() == "Backward pass"); + }; + // check for wait/record node + Node* wait_fw{nullptr}; + Node* wait_bw{nullptr}; + Node* record_fw{nullptr}; + Node* record_bw{nullptr}; + for (auto& node : graph.Nodes()) { + if (node.OpType() == "WaitEvent") { + if (is_backward(node)) { + wait_bw = &node; + } else { + wait_fw = &node; + } + } else if (node.OpType() == "RecordEvent") { + if (is_backward(node)) { + record_bw = &node; + } else { + record_fw = &node; + } + } + } + // every partition should have wait forward and record backward + ASSERT_TRUE(wait_fw && record_bw); + if (stageIdx == 2) { + // the last partition can perform back prop right away. It won't have record + // forward and wait backward + ASSERT_TRUE(!record_fw && !wait_bw); + } else { + ASSERT_TRUE(record_fw && wait_bw); + } + + // check for send/recv node + Node* send_fw{nullptr}; + Node* send_bw{nullptr}; + Node* recv_fw{nullptr}; + Node* recv_bw{nullptr}; + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Send") { + if (is_backward(node)) { + send_bw = &node; + } else { + send_fw = &node; + } + } else if (node.OpType() == "Recv") { + if (is_backward(node)) { + recv_bw = &node; + } else { + recv_fw = &node; + } + } + } + // except the last partion, each partition should have send forward and recv backward + if (stageIdx == 0 || stageIdx == 1) { + ASSERT_TRUE(send_fw && recv_bw); + } else { + ASSERT_TRUE(!send_fw && !recv_bw); + } + // except the first partion, each partition should have recv forward and send backward + if (stageIdx == 1 || stageIdx == 2) { + ASSERT_TRUE(recv_fw && send_bw); + } else { + ASSERT_TRUE(!recv_fw && !send_bw); + } + + auto mp = model->ToProto(); + std::ofstream ofs(output_file, std::ofstream::binary); + mp.SerializeToOstream(&ofs); + ofs.close(); + }; + + for (int i = 0; i < 3; ++i) { +#ifdef _WIN32 + auto surfix = std::to_wstring(i); +#else + auto surfix = std::to_string(i); +#endif + PathString input_file = filename_base + surfix + ORT_TSTR(".onnx"); + PathString output_file = filename_base + surfix + ORT_TSTR("_back.onnx"); + load_gradient_graph(i, input_file, output_file); + } +} + TEST(GradientGraphBuilderTest, TrainingSession_WithPipeline) { auto config = MakeBasicTrainingConfig(); //config.set_gradients_as_graph_outputs = true; diff --git a/orttraining/tools/scripts/pipeline_model_split.py b/orttraining/tools/scripts/pipeline_model_split.py index d0385edeac..008e626e32 100644 --- a/orttraining/tools/scripts/pipeline_model_split.py +++ b/orttraining/tools/scripts/pipeline_model_split.py @@ -4,7 +4,6 @@ import onnx from onnx import helper from onnx import TensorProto from onnx import OperatorSetIdProto - # Edge that needs to be cut for the split. # If the edge is feeding into more than one nodes, and not all the nodes belong to the same cut, # specify those consuming nodes that need to be cut @@ -29,31 +28,6 @@ def split_graph(model, split_edge_groups): new_send_nodes = [] new_recv_nodes = [] - # Add wait for initial inputs. This needs to be done first before new inputs - # are introduced from split - initializer_lists = [a.name for a in model.graph.initializer] - input_tensors = [ - value.name for value in model.graph.input if value.name not in initializer_lists] - - input_wait_signal = model.graph.input.add() - input_wait_signal.CopyFrom(helper.make_tensor_value_info( - 'input_wait_signal', onnx.TensorProto.INT64, None)) - - input_wait = model.graph.node.add() - input_wait.CopyFrom(helper.make_node( - 'WaitEvent', - inputs=['input_wait_signal'], - outputs=[], - domain=ms_domain)) - - for i in input_tensors: - for node in model.graph.node: - for j in range(len(node.input)): - if node.input[j] == i: - node.input[j] = i + '_sync' - - input_wait.input.extend(input_tensors) - input_wait.output.extend([i + '_sync' for i in input_tensors]) for cut_index in range(len(split_edge_groups)): edgeIds = split_edge_groups[cut_index] @@ -62,7 +36,7 @@ def split_graph(model, split_edge_groups): upstream_nodes = [] upstream_nodes_output_index = [] output_shapes = [] - + element_types = [] for id in edgeIds: for node in model.graph.node: if len(node.output) >= 1: @@ -70,25 +44,43 @@ def split_graph(model, split_edge_groups): if j == id: upstream_nodes.append(node) upstream_nodes_output_index.append(i) - for info in model.graph.value_info: - if info.name == id: - output_shapes.append(info.type) - - record_signal = model.graph.input.add() - record_signal.CopyFrom(helper.make_tensor_value_info( - 'record_input_signal' + str(cut_index), onnx.TensorProto.INT64, None)) - - wait_signal = model.graph.input.add() - wait_signal.CopyFrom(helper.make_tensor_value_info( - 'wait_input_signal' + str(cut_index), onnx.TensorProto.INT64, None)) + # assuming all tensors are of type float + element_types.append(1) + for info in model.graph.value_info: + if info.name == id: + output_shapes.append(info.type) + send_input_signal_name = 'send_input_signal' + str(cut_index) send_signal = model.graph.input.add() send_signal.CopyFrom(helper.make_tensor_value_info( - 'send_input_signal' + str(cut_index), onnx.TensorProto.BOOL, None)) + send_input_signal_name, onnx.TensorProto.BOOL, None)) + send_signal = helper.make_tensor( + send_input_signal_name, TensorProto.BOOL, (), (True,)) + model.graph.initializer.extend([send_signal]) + recv_input_signal_name = 'recv_input_signal' + str(cut_index) recv_signal = model.graph.input.add() recv_signal.CopyFrom(helper.make_tensor_value_info( - 'recv_input_signal' + str(cut_index), onnx.TensorProto.BOOL, None)) + recv_input_signal_name, onnx.TensorProto.BOOL, None)) + recv_signal = helper.make_tensor( + recv_input_signal_name, TensorProto.BOOL, (), (True,)) + model.graph.initializer.extend([recv_signal]) + + send_dst_rank_name = 'send_dst_rank' + str(cut_index) + send_dst_rank = model.graph.input.add() + send_dst_rank.CopyFrom(helper.make_tensor_value_info( + send_dst_rank_name, onnx.TensorProto.INT64, None)) + send_dst_rank = helper.make_tensor( + send_dst_rank_name, TensorProto.INT64, (), (cut_index + 1,)) + model.graph.initializer.extend([send_dst_rank]) + + recv_src_rank_name = 'recv_src_rank' + str(cut_index) + recv_src_rank = model.graph.input.add() + recv_src_rank.CopyFrom(helper.make_tensor_value_info( + recv_src_rank_name, onnx.TensorProto.INT64, None)) + recv_src_rank = helper.make_tensor( + recv_src_rank_name, TensorProto.INT64, (), (cut_index,)) + model.graph.initializer.extend([recv_src_rank]) # output signal from send after cut send_output_signal = model.graph.output.add() @@ -103,41 +95,23 @@ def split_graph(model, split_edge_groups): new_send = model.graph.node.add() new_send.CopyFrom(helper.make_node( 'Send', - inputs=['send_input_signal' + str(cut_index)], + inputs=[send_input_signal_name, send_dst_rank_name], outputs=['send_output_signal' + str(cut_index)], tag=0, - src=cut_index, - dst=cut_index + 1, domain=ms_domain, - element_type=7, # assuming all tensors are of type float + element_types=element_types, name='send')) new_receive = model.graph.node.add() new_receive.CopyFrom(helper.make_node( 'Recv', - inputs=['recv_input_signal' + str(cut_index)], + inputs=[recv_input_signal_name, recv_src_rank_name], outputs=['receive_output_signal' + str(cut_index)], - tag=1, - src=cut_index, - dst=cut_index + 1, + tag=0, domain=ms_domain, - element_type=7, # assuming all tensors are of type float + element_types=element_types, name='receive')) - new_wait = model.graph.node.add() - new_wait.CopyFrom(helper.make_node( - 'WaitEvent', - inputs=['wait_input_signal' + str(cut_index)], - outputs=[], - domain=ms_domain)) - - new_record = model.graph.node.add() - new_record.CopyFrom(helper.make_node( - 'RecordEvent', - inputs=['record_input_signal' + str(cut_index)], - outputs=[], - domain=ms_domain)) - for i in range(len(upstream_nodes)): n = upstream_nodes[i] idx = upstream_nodes_output_index[i] @@ -155,24 +129,16 @@ def split_graph(model, split_edge_groups): '_recv' + str(cut_index) add_expand_type(model, new_receive_output_name, output_type) - new_wait_output_name = output_edge_name + '_wait' + str(cut_index) - add_expand_type(model, new_wait_output_name, output_type) - # the order of data flow is: node-output -> record -> send -> recv -> wait -> node-input - new_record.input.extend([output_edge_name]) - new_record.output.extend([new_send_input_name]) - new_send.input.extend([new_send_input_name]) + new_send.input.extend([output_edge_name]) new_receive.output.extend([new_receive_output_name]) - new_wait.input.extend([new_receive_output_name]) - new_wait.output.extend([new_wait_output_name]) - for output_node in output_nodes: for i in range(len(output_node.input)): for edgeId in edgeIds: if output_node.input[i] == edgeId: - output_node.input[i] = new_wait_output_name + output_node.input[i] = new_receive_output_name new_send_nodes.append(new_send) new_recv_nodes.append(new_receive) @@ -236,9 +202,50 @@ def add_identity(model, cuttingEdge, newEdgeIdName): if output_nodes[i].input[j] == edgeId: output_nodes[i].input[j] = newEdgeIdName - return newEdgeIdName + return new_identity +def insert_identity(model, all_cut_inputs): + count = 0 + updated_edges = {} + new_added_identity = [] + split_edge_groups = [] + need_shape_inference = False + # Sweep the cut edge to see if there are edges feeding into nodes from two sub-graphs. If so, + # insert identity node after those edges with a new ID to distinguish the rest. + for cut_input in all_cut_inputs: + split_edges = [] + for i in cut_input: + if i.consumingNodes: + # if this edge has previously been modified, update its edgeId before inserting new identity + if i.edgeId in updated_edges: + i.edgeId = updated_edges[i.edgeId] + + new_edge_name = 'identity_output_' + str(count) + new_added_identity.append( + add_identity(model, i, new_edge_name)) + count += 1 + split_edges.append(new_edge_name) + updated_edges[i.edgeId] = new_edge_name + need_shape_inference = True + else: + split_edges.append(i.edgeId) + split_edge_groups.append(split_edges) + return split_edge_groups, new_added_identity, need_shape_inference + +# after the graph is split, remove the added identity node because identity op is not registered in gradient builder. + + +def remove_identity(model, new_added_identity): + for node in new_added_identity: + assert node.op_type == 'Identity' + output_nodes = [ + n for n in model.graph.node if node.output[0] in n.input] + for output_node in output_nodes: + for i in range(len(output_node.input)): + if output_node.input[i] == node.output[0]: + output_node.input[i] = node.input[0] + def find_all_connected_nodes(model, node): nodes0, inputs = find_all_input_nodes(model, node) nodes1, outputs = find_all_output_nodes(model, node) @@ -251,15 +258,36 @@ def get_index(node_list, node): found = [i for i, n in enumerate(node_list) if n == node] return found[0] if found else None +def get_identity_index_for_deleting(node_list, node): + for i, n in enumerate(node_list): + # The node's input name has been changed during send/recv insertion, + # but it is sufficient to just compare the type and outputs. + if (n.op_type == 'Identity' and n.output == node.output): + return i + return None + # traverse the graph, group connected nodes and generate subgraph -def generate_subgraph(model, start_nodes): +def generate_subgraph(model, start_nodes, identity_node_list): subgraphs = [] main_graph = onnx.ModelProto() main_graph.CopyFrom(model) + # remove added identity node before copy to subgraph + identity_node_index = [] + for n in identity_node_list: + identity_node_index.append(get_identity_index_for_deleting(main_graph.graph.node, n)) + identity_node_index.sort(reverse=True) + + for i in reversed(range(len(main_graph.graph.node))): + try: + if i in identity_node_index: + del main_graph.graph.node[i] + except: + print("error deleting identity node", i) + all_visited_nodes = [] model_count = len(start_nodes) for start in reversed(start_nodes): @@ -362,29 +390,8 @@ def main(): output_model_names = [os.path.splitext(input_model_name)[0] + '_' + str(i) + '.onnx' for i in range(stage_count)] - split_edge_groups = [] - count = 0 - updated_edges = {} - need_shape_inference = False - # Sweep the cut edge to see if there are edges feeding into nodes from two sub-graphs. If so, - # insert identity node after those edges with a new ID to distinguish the rest. - for cut_input in all_cut_inputs: - split_edges = [] - for i in cut_input: - if i.consumingNodes: - # if this edge has previously been modified, update its edgeId before inserting new identity - if i.edgeId in updated_edges: - i.edgeId = updated_edges[i.edgeId] + split_edge_groups, new_identity, need_shape_inference = insert_identity(model, all_cut_inputs) - new_edge_name = 'identity_output_' + str(count) - add_identity(model, i, new_edge_name) - count += 1 - split_edges.append(new_edge_name) - updated_edges[i.edgeId] = new_edge_name - need_shape_inference = True - else: - split_edges.append(i.edgeId) - split_edge_groups.append(split_edges) # new edge is being added, need to re-inference shape if need_shape_inference: @@ -392,11 +399,13 @@ def main(): # after all need-to-be-cut edges identified, split the graph new_sends, new_receives = split_graph(model, split_edge_groups) - sub_graphs = generate_subgraph(model, new_receives) + remove_identity(model, new_identity) + sub_graphs = generate_subgraph(model, new_receives, new_identity) for i in range(stage_count): sub_graphs[i] = onnx.shape_inference.infer_shapes(sub_graphs[i]) onnx.save(sub_graphs[i], output_model_names[i]) + print("save to file: ", output_model_names[i]) if __name__ == "__main__":