mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-02 03:55:34 +00:00
Add pipeline transformer for wait/record node (#3513)
* pipeline transformer * clean up * address feedback * add record/wait for first stage and updated split script * address feedback * make recv/send signal as initializer * merge * address feedback * unify input and initializer * address feedback and bug fix * minor fix * windows build * fix
This commit is contained in:
parent
6136fd0789
commit
f1ba9aaf34
11 changed files with 551 additions and 113 deletions
BIN
onnxruntime/test/testdata/test_training_model_0.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/test_training_model_0.onnx
vendored
Normal file
Binary file not shown.
BIN
onnxruntime/test/testdata/test_training_model_1.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/test_training_model_1.onnx
vendored
Normal file
Binary file not shown.
BIN
onnxruntime/test/testdata/test_training_model_2.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/test_training_model_2.onnx
vendored
Normal file
Binary file not shown.
|
|
@ -1045,7 +1045,7 @@ IMPLEMENT_GRADIENT_BUILDER(GetSendGradient) {
|
|||
}
|
||||
|
||||
return std::vector<NodeDef>{
|
||||
NodeDef("Recv",
|
||||
NodeDef(OpDef{"Recv", kMSDomain, 1},
|
||||
{O(0), I(1)}, // {Signal, Remote}
|
||||
out_args,
|
||||
SrcNodeAttributes())};
|
||||
|
|
@ -1057,14 +1057,14 @@ IMPLEMENT_GRADIENT_BUILDER(GetRecvGradient) {
|
|||
|
||||
std::vector<ArgDef> in_args;
|
||||
in_args.push_back(O(0)); // Signal
|
||||
in_args.push_back(I(0)); // Remote
|
||||
in_args.push_back(I(1)); // Remote
|
||||
|
||||
for (int i = 1; i < GetSrcNodeOutputSize(); ++i) {
|
||||
in_args.push_back(GO(i)); // Data
|
||||
}
|
||||
|
||||
return std::vector<NodeDef>{
|
||||
NodeDef("Send",
|
||||
NodeDef(OpDef{"Send", kMSDomain, 1},
|
||||
in_args,
|
||||
{GI(0)}, // Signal
|
||||
SrcNodeAttributes())};
|
||||
|
|
|
|||
|
|
@ -1670,7 +1670,7 @@ Return true if all elements are true and false otherwise.
|
|||
"Allow inputs and outputs to be any kind of tensor.")
|
||||
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
|
||||
if (ctx.getNumInputs() < ctx.getNumOutputs() + 1)
|
||||
fail_shape_inference("WaitEvent must have at least (num_outputs + 1) inputs.");
|
||||
fail_shape_inference("RecordEvent must have at least (num_outputs + 1) inputs.");
|
||||
|
||||
// note: if num_input > num_output + 1,
|
||||
// the additional inputs (idx >= num_ouput + 1) are regarded as dependencies
|
||||
|
|
|
|||
279
orttraining/orttraining/core/graph/pipeline_transformer.cc
Normal file
279
orttraining/orttraining/core/graph/pipeline_transformer.cc
Normal file
|
|
@ -0,0 +1,279 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "orttraining/core/graph/pipeline_transformer.h"
|
||||
|
||||
using namespace onnxruntime::common;
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace training {
|
||||
|
||||
void GetPipelineSendOutput(const Graph& graph, std::string& loss_name) {
|
||||
for (auto& node : graph.Nodes()) {
|
||||
if (!node.OpType().compare("Send")) {
|
||||
// send op should always have an output, which is the OutputSignal.
|
||||
loss_name = node.OutputDefs()[0]->Name();
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool IsBackward(Node& node) {
|
||||
return (node.Description() == "Backward pass");
|
||||
}
|
||||
|
||||
void AddInputEvent(Graph& graph, const std::string& op_name,
|
||||
bool is_forward,
|
||||
std::vector<NodeArg*>& input_args,
|
||||
std::vector<std::string>& new_input_names) {
|
||||
ONNX_NAMESPACE::TypeProto event_type_proto;
|
||||
event_type_proto.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
|
||||
|
||||
auto event_id_name = graph.GenerateNodeArgName(op_name + (is_forward ? "_fw" : "_bw") + "_event_id");
|
||||
auto& event_id = graph.GetOrCreateNodeArg(event_id_name, &event_type_proto);
|
||||
new_input_names.push_back(event_id_name);
|
||||
|
||||
input_args.push_back(&event_id);
|
||||
}
|
||||
|
||||
// gradient graph can contain some dangling leaf nodes. Add them all to WaitEvent
|
||||
// backward node's input.
|
||||
void FindLeafNodes(Graph& graph, std::vector<NodeArg*>& input_args) {
|
||||
for (auto& node : graph.Nodes()) {
|
||||
if (!IsBackward(node)) {
|
||||
// only check backward node
|
||||
continue;
|
||||
}
|
||||
bool find_consumer_nodes = false;
|
||||
std::vector<NodeArg*>& outputs = node.MutableOutputDefs();
|
||||
for (auto& output : outputs) {
|
||||
std::vector<const Node*> consumer_nodes = graph.GetConsumerNodes(output->Name());
|
||||
if (consumer_nodes.size() > 0) {
|
||||
find_consumer_nodes = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!find_consumer_nodes && outputs.size() > 0) {
|
||||
input_args.push_back(outputs[0]);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
NodeArg& CreateNodeArg(Graph& graph, const NodeArg* base_arg) {
|
||||
const auto& new_name = graph.GenerateNodeArgName(base_arg->Name());
|
||||
ONNX_NAMESPACE::TypeProto type_proto(*(base_arg->TypeAsProto()));
|
||||
if (graph.GetNodeArg(new_name) != nullptr) {
|
||||
ORT_THROW("Node with name ", new_name, " already exists.");
|
||||
}
|
||||
return graph.GetOrCreateNodeArg(new_name, &type_proto);
|
||||
}
|
||||
|
||||
Status AddRecordBackward(Graph& graph,
|
||||
Node* send_bw,
|
||||
std::vector<std::string>& new_input_names,
|
||||
std::vector<std::string>& new_output_names) {
|
||||
std::vector<NodeArg*> input_args;
|
||||
AddInputEvent(graph, "RecordEvent", false /* is_forward */, input_args, new_input_names);
|
||||
std::vector<NodeArg*> output_args{};
|
||||
|
||||
if (send_bw) {
|
||||
// if we have send op in backward pass (at the end of the graph), we make sure the RecordEvent happens
|
||||
// after that send by adding Send's outputs to RecordEvent's input list.
|
||||
input_args.insert(std::end(input_args),
|
||||
std::begin(send_bw->MutableOutputDefs()),
|
||||
std::end(send_bw->MutableOutputDefs()));
|
||||
}
|
||||
FindLeafNodes(graph, input_args);
|
||||
|
||||
// Optimizer will be added after applying pipeline transformer. To support partial graph evaluation,
|
||||
// the added Record backward op will have its first passthrough input as output.
|
||||
ORT_RETURN_IF_NOT(input_args.size() >= 2, "RecordEvent backward op at least have two inputs.")
|
||||
auto& new_output = CreateNodeArg(graph, input_args[1]); // the first input is signal, not passing through
|
||||
output_args.push_back(&new_output);
|
||||
new_output_names.push_back(new_output.Name());
|
||||
|
||||
graph.AddNode(graph.GenerateNodeName("RecordEvent"),
|
||||
"RecordEvent",
|
||||
"Backward pass",
|
||||
input_args,
|
||||
output_args,
|
||||
nullptr,
|
||||
kMSDomain);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status AddWaitForward(Graph& graph, Node* /* recv_fw */, std::vector<std::string>& new_input_names) {
|
||||
// Append old_input to input_args and return its pass-through value. Note that
|
||||
// input_args and output_args are Wait's inputs and outputs, respectively.
|
||||
auto update_wait_input_output = [&](NodeArg* old_input,
|
||||
std::vector<NodeArg*>& input_args,
|
||||
std::vector<NodeArg*>& output_args) -> NodeArg& {
|
||||
input_args.push_back(old_input);
|
||||
|
||||
const auto& new_name = graph.GenerateNodeArgName(old_input->Name());
|
||||
ONNX_NAMESPACE::TypeProto type_proto(*(old_input->TypeAsProto()));
|
||||
|
||||
auto& wait_output = graph.GetOrCreateNodeArg(new_name, &type_proto);
|
||||
output_args.push_back(&wait_output);
|
||||
|
||||
return wait_output;
|
||||
};
|
||||
|
||||
std::vector<NodeArg*> input_args;
|
||||
std::vector<NodeArg*> output_args;
|
||||
AddInputEvent(graph, "WaitEvent", true /* is_forward */, input_args, new_input_names);
|
||||
const std::vector<const NodeArg*>& graph_inputs = graph.GetInputsIncludingInitializers();
|
||||
|
||||
if (graph_inputs.size() == 0){
|
||||
ORT_THROW("Graph ", graph.Name(), " doesn't have any inputs.");
|
||||
}
|
||||
|
||||
for (auto& input_arg : graph_inputs) {
|
||||
NodeArg* mutable_input = graph.GetNodeArg(input_arg->Name());
|
||||
auto& wait_output = update_wait_input_output(mutable_input, input_args, output_args);
|
||||
std::vector<Node*> nodes = graph.GetMutableConsumerNodes(input_arg->Name());
|
||||
for (auto& consumer_node : nodes) {
|
||||
for (auto& i : consumer_node->MutableInputDefs()) {
|
||||
if (i->Name() == input_arg->Name()) {
|
||||
// if the node is fed by input, re-direct it to be fed by WaitEvent's output.
|
||||
i = &wait_output;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
graph.AddNode(graph.GenerateNodeName("WaitEvent"),
|
||||
"WaitEvent",
|
||||
"",
|
||||
input_args,
|
||||
output_args,
|
||||
nullptr,
|
||||
kMSDomain);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status AddOrSkipRecordForwardWaitBackward(Graph& graph, Node* send_fw, Node* recv_bw, std::vector<std::string>& new_input_names) {
|
||||
if (!send_fw != !recv_bw){
|
||||
ORT_THROW("Graph requires either having both send forward node "
|
||||
"and recv backword node, or none of them. Currently the graph "
|
||||
"has send forward: ", send_fw, " and recv backward: ", recv_bw);
|
||||
}
|
||||
|
||||
if (!send_fw && !recv_bw){
|
||||
// Last partition doesn't have send forwrad and recv backward. No insert needed.
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// if we have a send forward op followed by a recv backward op, insert WaitEvent and RecordEvent in between.
|
||||
Node* record_node = nullptr;
|
||||
Node* wait_node = nullptr;
|
||||
|
||||
// Insert RecordEvent
|
||||
{
|
||||
std::vector<NodeArg*> input_args;
|
||||
std::vector<NodeArg*> output_args;
|
||||
AddInputEvent(graph, "RecordEvent", true /* is_forward */, input_args, new_input_names);
|
||||
|
||||
// Add send forward op's output as record op's input and output
|
||||
for (auto& output : send_fw->MutableOutputDefs()) {
|
||||
auto& new_output = CreateNodeArg(graph, output);
|
||||
output_args.push_back(&new_output);
|
||||
input_args.push_back(output);
|
||||
}
|
||||
|
||||
auto& new_node = graph.AddNode(graph.GenerateNodeName("RecordEvent"),
|
||||
"RecordEvent",
|
||||
"",
|
||||
input_args,
|
||||
output_args, /* output */
|
||||
{}, /* attribute */
|
||||
kMSDomain);
|
||||
record_node = &new_node;
|
||||
}
|
||||
// Insert WaitEvent
|
||||
{
|
||||
std::vector<NodeArg*> input_args;
|
||||
std::vector<NodeArg*> output_args;
|
||||
AddInputEvent(graph, "WaitEvent", false /* is_forward */, input_args, new_input_names);
|
||||
|
||||
input_args.insert(std::end(input_args),
|
||||
std::begin(record_node->MutableOutputDefs()),
|
||||
std::end(record_node->MutableOutputDefs()));
|
||||
|
||||
auto& input = recv_bw->MutableInputDefs()[0];
|
||||
auto& new_output = CreateNodeArg(graph, input);
|
||||
output_args.push_back(&new_output);
|
||||
input = &new_output;
|
||||
|
||||
auto& new_node = graph.AddNode(graph.GenerateNodeName("WaitEvent"),
|
||||
"WaitEvent",
|
||||
"Backward pass",
|
||||
input_args,
|
||||
output_args, /* output */
|
||||
{}, /* attribute */
|
||||
kMSDomain);
|
||||
wait_node = &new_node;
|
||||
ORT_UNUSED_PARAMETER(wait_node);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TransformGraphForPipeline(Graph& graph) {
|
||||
// insert WaitEvent and RecordEvent to the partition
|
||||
Node* send_fw{nullptr};
|
||||
Node* send_bw{nullptr};
|
||||
Node* recv_fw{nullptr};
|
||||
Node* recv_bw{nullptr};
|
||||
for (auto& node : graph.Nodes()) {
|
||||
if (node.OpType() == "Send") {
|
||||
if (IsBackward(node)) {
|
||||
send_bw = &node;
|
||||
} else {
|
||||
send_fw = &node;
|
||||
}
|
||||
} else if (node.OpType() == "Recv") {
|
||||
if (IsBackward(node)) {
|
||||
recv_bw = &node;
|
||||
} else {
|
||||
recv_fw = &node;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::string> new_input_names;
|
||||
std::vector<std::string> new_output_names;
|
||||
|
||||
ORT_RETURN_IF_ERROR(AddRecordBackward(graph, send_bw, new_input_names, new_output_names));
|
||||
ORT_RETURN_IF_ERROR(AddWaitForward(graph, recv_fw, new_input_names));
|
||||
ORT_RETURN_IF_ERROR(AddOrSkipRecordForwardWaitBackward(graph, send_fw, recv_bw, new_input_names));
|
||||
|
||||
auto fill_node_args = [&](const Graph& graph,
|
||||
const std::vector<const NodeArg*>& existed_node_args,
|
||||
std::vector<std::string>& new_node_arg_names,
|
||||
std::vector<const NodeArg*>& merged_node_args) {
|
||||
merged_node_args.insert(merged_node_args.end(), existed_node_args.begin(), existed_node_args.end());
|
||||
for (auto& name : new_node_arg_names) {
|
||||
merged_node_args.push_back(graph.GetNodeArg(name));
|
||||
}
|
||||
};
|
||||
|
||||
const std::vector<const NodeArg*>& graph_inputs = graph.GetInputsIncludingInitializers();
|
||||
std::vector<const NodeArg*> inputs_args_sets;
|
||||
inputs_args_sets.reserve(graph_inputs.size() + new_input_names.size());
|
||||
fill_node_args(graph, graph_inputs, new_input_names, inputs_args_sets);
|
||||
|
||||
const std::vector<const NodeArg*>& graph_outputs = graph.GetOutputs();
|
||||
std::vector<const NodeArg*> outputs_args_sets;
|
||||
outputs_args_sets.reserve(graph_outputs.size() + new_output_names.size());
|
||||
fill_node_args(graph, graph_outputs, new_output_names, outputs_args_sets);
|
||||
|
||||
graph.SetInputs(inputs_args_sets);
|
||||
graph.SetOutputs(outputs_args_sets);
|
||||
graph.SetGraphResolveNeeded();
|
||||
graph.SetGraphProtoSyncNeeded();
|
||||
return graph.Resolve();
|
||||
}
|
||||
|
||||
} // namespace training
|
||||
} // namespace onnxruntime
|
||||
15
orttraining/orttraining/core/graph/pipeline_transformer.h
Normal file
15
orttraining/orttraining/core/graph/pipeline_transformer.h
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/graph/graph.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace training {
|
||||
|
||||
void GetPipelineSendOutput(const Graph& graph, std::string& loss_name);
|
||||
common::Status TransformGraphForPipeline(Graph& graph);
|
||||
|
||||
} // namespace training
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -15,6 +15,7 @@
|
|||
#include "core/optimizer/rule_based_graph_transformer.h"
|
||||
#include "orttraining/core/graph/mixed_precision_transformer.h"
|
||||
#include "orttraining/core/graph/tensorboard_transformer.h"
|
||||
#include "orttraining/core/graph/pipeline_transformer.h"
|
||||
#include "orttraining/core/graph/gradient_builder_base.h"
|
||||
|
||||
//Gist Encoding
|
||||
|
|
@ -128,15 +129,23 @@ Status TrainingSession::ConfigureForTraining(
|
|||
is_mixed_precision_enabled_ = config.mixed_precision_config.has_value();
|
||||
|
||||
std::string loss_name{};
|
||||
const optional<LossFunctionInfo> loss_function_info =
|
||||
config.loss_function_config.has_value()
|
||||
? config.loss_function_config.value().loss_function_info
|
||||
: optional<LossFunctionInfo>{};
|
||||
optional<std::string> loss_scale_input_name =
|
||||
is_mixed_precision_enabled_ ? optional<std::string>{""} : optional<std::string>{};
|
||||
ORT_RETURN_IF_ERROR(ConfigureLossFunction(
|
||||
config.loss_name, loss_function_info,
|
||||
loss_scale_input_name.has_value() ? &loss_scale_input_name.value() : nullptr, loss_name));
|
||||
is_mixed_precision_enabled_ ? optional<std::string>{""} : optional<std::string>{};
|
||||
if (config.use_pipeline) {
|
||||
// if use pipeline, first check if model contains send op. If it does, set the
|
||||
// send node's output as the start tensor to build gradient graph
|
||||
GetPipelineSendOutput(model_->MainGraph(), loss_name);
|
||||
}
|
||||
if (loss_name.empty()) {
|
||||
const optional<LossFunctionInfo> loss_function_info =
|
||||
config.loss_function_config.has_value()
|
||||
? config.loss_function_config.value().loss_function_info
|
||||
: optional<LossFunctionInfo>{};
|
||||
ORT_RETURN_IF_ERROR(ConfigureLossFunction(
|
||||
config.loss_name, loss_function_info,
|
||||
loss_scale_input_name.has_value() ? &loss_scale_input_name.value() : nullptr, loss_name));
|
||||
}
|
||||
|
||||
ORT_ENFORCE(
|
||||
!loss_scale_input_name.has_value() || !loss_scale_input_name.value().empty(),
|
||||
"loss_scale_input_name should not be set to an empty string.");
|
||||
|
|
@ -170,7 +179,6 @@ Status TrainingSession::ConfigureForTraining(
|
|||
<< weight_names_stream.str();
|
||||
}
|
||||
|
||||
// add gradient graph
|
||||
ORT_RETURN_IF_ERROR(BuildGradientGraph(
|
||||
weight_names_to_train, loss_name, config.set_gradients_as_graph_outputs));
|
||||
|
||||
|
|
@ -182,12 +190,30 @@ Status TrainingSession::ConfigureForTraining(
|
|||
weight_names_to_train, mixed_precision_config.use_fp16_initializers, fp32_weight_name_to_fp16_node_arg));
|
||||
}
|
||||
|
||||
if (config.use_pipeline) {
|
||||
ORT_RETURN_IF_ERROR(InsertPipelineOps());
|
||||
}
|
||||
|
||||
// All non-float tensors are not trainable. Remove those weights.
|
||||
// TODO: this is a temp workaround for removing rank tensor before adding optimizer.
|
||||
// Re-visit after we port logic for model splitting and hence know the rank tensor name.
|
||||
for (auto it = weights_to_train_.begin(); it != weights_to_train_.end();) {
|
||||
const auto* node_arg = model_->MainGraph().GetNodeArg(*it);
|
||||
ORT_RETURN_IF_NOT(node_arg, "Failed to get NodeArg with name ", *it);
|
||||
if (node_arg->TypeAsProto()->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
|
||||
it = weights_to_train_.erase(it);
|
||||
}
|
||||
else{
|
||||
++it;
|
||||
}
|
||||
}
|
||||
|
||||
// add optimizer or gradient accumulation
|
||||
if (config.optimizer_config.has_value()) {
|
||||
OptimizerGraphConfig opt_graph_config{};
|
||||
std::unordered_map<std::string, OptimizerNodeConfig> opt_node_configs{};
|
||||
ORT_RETURN_IF_ERROR(SetupOptimizerParams(
|
||||
weight_names_to_train, fp32_weight_name_to_fp16_node_arg,
|
||||
weights_to_train_, fp32_weight_name_to_fp16_node_arg,
|
||||
loss_scale_input_name, config, opt_graph_config, opt_node_configs));
|
||||
|
||||
TrainingConfigurationResult::OptimizerConfigurationResult optimizer_config_result{};
|
||||
|
|
@ -198,7 +224,7 @@ Status TrainingSession::ConfigureForTraining(
|
|||
config_result.opt_config_result = optimizer_config_result;
|
||||
} else {
|
||||
if (config.gradient_accumulation_steps > 1) {
|
||||
ORT_RETURN_IF_ERROR(BuildAccumulationNode(weight_names_to_train));
|
||||
ORT_RETURN_IF_ERROR(BuildAccumulationNode(weights_to_train_));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -439,6 +465,11 @@ Status TrainingSession::AddTensorboard(const std::string& summary_name,
|
|||
return DoPostLoadProcessing(*model_);
|
||||
}
|
||||
|
||||
Status TrainingSession::InsertPipelineOps() {
|
||||
ORT_RETURN_IF_ERROR(TransformGraphForPipeline(model_->MainGraph()));
|
||||
return DoPostLoadProcessing(*model_);
|
||||
}
|
||||
|
||||
Status TrainingSession::ConfigureLossFunction(
|
||||
const optional<std::string>& external_loss_name,
|
||||
const optional<LossFunctionInfo>& loss_function_info,
|
||||
|
|
|
|||
|
|
@ -132,6 +132,9 @@ class TrainingSession : public InferenceSession {
|
|||
// The optimizer configuration.
|
||||
// If not provided, no optimizer is added.
|
||||
optional<OptimizerConfiguration> optimizer_config{};
|
||||
|
||||
// Whether to use pipeline in training.
|
||||
bool use_pipeline{false};
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
@ -262,6 +265,7 @@ class TrainingSession : public InferenceSession {
|
|||
const std::vector<std::string>& norm_nodes,
|
||||
const bool dump_convergence_metrics);
|
||||
|
||||
common::Status InsertPipelineOps();
|
||||
common::Status ApplyTransformationsToMainGraph();
|
||||
|
||||
/** configure initial transformers for training */
|
||||
|
|
|
|||
|
|
@ -997,6 +997,106 @@ class PipelineBatchPlanner {
|
|||
}
|
||||
};
|
||||
|
||||
// verify pipeline config can load and gradient graph can construct.
|
||||
TEST(GradientGraphBuilderTest, TrainingSession_PipelineTransform_base) {
|
||||
PathString filename_base = ORT_TSTR("testdata/test_training_model_");
|
||||
|
||||
auto load_gradient_graph = [](int stageIdx, PathString& input_file, PathString& output_file) {
|
||||
auto config = MakeBasicTrainingConfig();
|
||||
|
||||
config.use_pipeline = true;
|
||||
|
||||
PathString backprop_model_file;
|
||||
ASSERT_STATUS_OK(BuildBackPropGraph(input_file, config, backprop_model_file));
|
||||
|
||||
std::shared_ptr<Model> model;
|
||||
ASSERT_TRUE(Model::Load(backprop_model_file, model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK());
|
||||
|
||||
Graph& graph = model->MainGraph();
|
||||
auto is_backward = [](Node& node) {
|
||||
return (node.Description() == "Backward pass");
|
||||
};
|
||||
// check for wait/record node
|
||||
Node* wait_fw{nullptr};
|
||||
Node* wait_bw{nullptr};
|
||||
Node* record_fw{nullptr};
|
||||
Node* record_bw{nullptr};
|
||||
for (auto& node : graph.Nodes()) {
|
||||
if (node.OpType() == "WaitEvent") {
|
||||
if (is_backward(node)) {
|
||||
wait_bw = &node;
|
||||
} else {
|
||||
wait_fw = &node;
|
||||
}
|
||||
} else if (node.OpType() == "RecordEvent") {
|
||||
if (is_backward(node)) {
|
||||
record_bw = &node;
|
||||
} else {
|
||||
record_fw = &node;
|
||||
}
|
||||
}
|
||||
}
|
||||
// every partition should have wait forward and record backward
|
||||
ASSERT_TRUE(wait_fw && record_bw);
|
||||
if (stageIdx == 2) {
|
||||
// the last partition can perform back prop right away. It won't have record
|
||||
// forward and wait backward
|
||||
ASSERT_TRUE(!record_fw && !wait_bw);
|
||||
} else {
|
||||
ASSERT_TRUE(record_fw && wait_bw);
|
||||
}
|
||||
|
||||
// check for send/recv node
|
||||
Node* send_fw{nullptr};
|
||||
Node* send_bw{nullptr};
|
||||
Node* recv_fw{nullptr};
|
||||
Node* recv_bw{nullptr};
|
||||
for (auto& node : graph.Nodes()) {
|
||||
if (node.OpType() == "Send") {
|
||||
if (is_backward(node)) {
|
||||
send_bw = &node;
|
||||
} else {
|
||||
send_fw = &node;
|
||||
}
|
||||
} else if (node.OpType() == "Recv") {
|
||||
if (is_backward(node)) {
|
||||
recv_bw = &node;
|
||||
} else {
|
||||
recv_fw = &node;
|
||||
}
|
||||
}
|
||||
}
|
||||
// except the last partion, each partition should have send forward and recv backward
|
||||
if (stageIdx == 0 || stageIdx == 1) {
|
||||
ASSERT_TRUE(send_fw && recv_bw);
|
||||
} else {
|
||||
ASSERT_TRUE(!send_fw && !recv_bw);
|
||||
}
|
||||
// except the first partion, each partition should have recv forward and send backward
|
||||
if (stageIdx == 1 || stageIdx == 2) {
|
||||
ASSERT_TRUE(recv_fw && send_bw);
|
||||
} else {
|
||||
ASSERT_TRUE(!recv_fw && !send_bw);
|
||||
}
|
||||
|
||||
auto mp = model->ToProto();
|
||||
std::ofstream ofs(output_file, std::ofstream::binary);
|
||||
mp.SerializeToOstream(&ofs);
|
||||
ofs.close();
|
||||
};
|
||||
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
#ifdef _WIN32
|
||||
auto surfix = std::to_wstring(i);
|
||||
#else
|
||||
auto surfix = std::to_string(i);
|
||||
#endif
|
||||
PathString input_file = filename_base + surfix + ORT_TSTR(".onnx");
|
||||
PathString output_file = filename_base + surfix + ORT_TSTR("_back.onnx");
|
||||
load_gradient_graph(i, input_file, output_file);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(GradientGraphBuilderTest, TrainingSession_WithPipeline) {
|
||||
auto config = MakeBasicTrainingConfig();
|
||||
//config.set_gradients_as_graph_outputs = true;
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ import onnx
|
|||
from onnx import helper
|
||||
from onnx import TensorProto
|
||||
from onnx import OperatorSetIdProto
|
||||
|
||||
# Edge that needs to be cut for the split.
|
||||
# If the edge is feeding into more than one nodes, and not all the nodes belong to the same cut,
|
||||
# specify those consuming nodes that need to be cut
|
||||
|
|
@ -29,31 +28,6 @@ def split_graph(model, split_edge_groups):
|
|||
|
||||
new_send_nodes = []
|
||||
new_recv_nodes = []
|
||||
# Add wait for initial inputs. This needs to be done first before new inputs
|
||||
# are introduced from split
|
||||
initializer_lists = [a.name for a in model.graph.initializer]
|
||||
input_tensors = [
|
||||
value.name for value in model.graph.input if value.name not in initializer_lists]
|
||||
|
||||
input_wait_signal = model.graph.input.add()
|
||||
input_wait_signal.CopyFrom(helper.make_tensor_value_info(
|
||||
'input_wait_signal', onnx.TensorProto.INT64, None))
|
||||
|
||||
input_wait = model.graph.node.add()
|
||||
input_wait.CopyFrom(helper.make_node(
|
||||
'WaitEvent',
|
||||
inputs=['input_wait_signal'],
|
||||
outputs=[],
|
||||
domain=ms_domain))
|
||||
|
||||
for i in input_tensors:
|
||||
for node in model.graph.node:
|
||||
for j in range(len(node.input)):
|
||||
if node.input[j] == i:
|
||||
node.input[j] = i + '_sync'
|
||||
|
||||
input_wait.input.extend(input_tensors)
|
||||
input_wait.output.extend([i + '_sync' for i in input_tensors])
|
||||
|
||||
for cut_index in range(len(split_edge_groups)):
|
||||
edgeIds = split_edge_groups[cut_index]
|
||||
|
|
@ -62,7 +36,7 @@ def split_graph(model, split_edge_groups):
|
|||
upstream_nodes = []
|
||||
upstream_nodes_output_index = []
|
||||
output_shapes = []
|
||||
|
||||
element_types = []
|
||||
for id in edgeIds:
|
||||
for node in model.graph.node:
|
||||
if len(node.output) >= 1:
|
||||
|
|
@ -70,25 +44,43 @@ def split_graph(model, split_edge_groups):
|
|||
if j == id:
|
||||
upstream_nodes.append(node)
|
||||
upstream_nodes_output_index.append(i)
|
||||
for info in model.graph.value_info:
|
||||
if info.name == id:
|
||||
output_shapes.append(info.type)
|
||||
|
||||
record_signal = model.graph.input.add()
|
||||
record_signal.CopyFrom(helper.make_tensor_value_info(
|
||||
'record_input_signal' + str(cut_index), onnx.TensorProto.INT64, None))
|
||||
|
||||
wait_signal = model.graph.input.add()
|
||||
wait_signal.CopyFrom(helper.make_tensor_value_info(
|
||||
'wait_input_signal' + str(cut_index), onnx.TensorProto.INT64, None))
|
||||
# assuming all tensors are of type float
|
||||
element_types.append(1)
|
||||
for info in model.graph.value_info:
|
||||
if info.name == id:
|
||||
output_shapes.append(info.type)
|
||||
|
||||
send_input_signal_name = 'send_input_signal' + str(cut_index)
|
||||
send_signal = model.graph.input.add()
|
||||
send_signal.CopyFrom(helper.make_tensor_value_info(
|
||||
'send_input_signal' + str(cut_index), onnx.TensorProto.BOOL, None))
|
||||
send_input_signal_name, onnx.TensorProto.BOOL, None))
|
||||
send_signal = helper.make_tensor(
|
||||
send_input_signal_name, TensorProto.BOOL, (), (True,))
|
||||
model.graph.initializer.extend([send_signal])
|
||||
|
||||
recv_input_signal_name = 'recv_input_signal' + str(cut_index)
|
||||
recv_signal = model.graph.input.add()
|
||||
recv_signal.CopyFrom(helper.make_tensor_value_info(
|
||||
'recv_input_signal' + str(cut_index), onnx.TensorProto.BOOL, None))
|
||||
recv_input_signal_name, onnx.TensorProto.BOOL, None))
|
||||
recv_signal = helper.make_tensor(
|
||||
recv_input_signal_name, TensorProto.BOOL, (), (True,))
|
||||
model.graph.initializer.extend([recv_signal])
|
||||
|
||||
send_dst_rank_name = 'send_dst_rank' + str(cut_index)
|
||||
send_dst_rank = model.graph.input.add()
|
||||
send_dst_rank.CopyFrom(helper.make_tensor_value_info(
|
||||
send_dst_rank_name, onnx.TensorProto.INT64, None))
|
||||
send_dst_rank = helper.make_tensor(
|
||||
send_dst_rank_name, TensorProto.INT64, (), (cut_index + 1,))
|
||||
model.graph.initializer.extend([send_dst_rank])
|
||||
|
||||
recv_src_rank_name = 'recv_src_rank' + str(cut_index)
|
||||
recv_src_rank = model.graph.input.add()
|
||||
recv_src_rank.CopyFrom(helper.make_tensor_value_info(
|
||||
recv_src_rank_name, onnx.TensorProto.INT64, None))
|
||||
recv_src_rank = helper.make_tensor(
|
||||
recv_src_rank_name, TensorProto.INT64, (), (cut_index,))
|
||||
model.graph.initializer.extend([recv_src_rank])
|
||||
|
||||
# output signal from send after cut
|
||||
send_output_signal = model.graph.output.add()
|
||||
|
|
@ -103,41 +95,23 @@ def split_graph(model, split_edge_groups):
|
|||
new_send = model.graph.node.add()
|
||||
new_send.CopyFrom(helper.make_node(
|
||||
'Send',
|
||||
inputs=['send_input_signal' + str(cut_index)],
|
||||
inputs=[send_input_signal_name, send_dst_rank_name],
|
||||
outputs=['send_output_signal' + str(cut_index)],
|
||||
tag=0,
|
||||
src=cut_index,
|
||||
dst=cut_index + 1,
|
||||
domain=ms_domain,
|
||||
element_type=7, # assuming all tensors are of type float
|
||||
element_types=element_types,
|
||||
name='send'))
|
||||
|
||||
new_receive = model.graph.node.add()
|
||||
new_receive.CopyFrom(helper.make_node(
|
||||
'Recv',
|
||||
inputs=['recv_input_signal' + str(cut_index)],
|
||||
inputs=[recv_input_signal_name, recv_src_rank_name],
|
||||
outputs=['receive_output_signal' + str(cut_index)],
|
||||
tag=1,
|
||||
src=cut_index,
|
||||
dst=cut_index + 1,
|
||||
tag=0,
|
||||
domain=ms_domain,
|
||||
element_type=7, # assuming all tensors are of type float
|
||||
element_types=element_types,
|
||||
name='receive'))
|
||||
|
||||
new_wait = model.graph.node.add()
|
||||
new_wait.CopyFrom(helper.make_node(
|
||||
'WaitEvent',
|
||||
inputs=['wait_input_signal' + str(cut_index)],
|
||||
outputs=[],
|
||||
domain=ms_domain))
|
||||
|
||||
new_record = model.graph.node.add()
|
||||
new_record.CopyFrom(helper.make_node(
|
||||
'RecordEvent',
|
||||
inputs=['record_input_signal' + str(cut_index)],
|
||||
outputs=[],
|
||||
domain=ms_domain))
|
||||
|
||||
for i in range(len(upstream_nodes)):
|
||||
n = upstream_nodes[i]
|
||||
idx = upstream_nodes_output_index[i]
|
||||
|
|
@ -155,24 +129,16 @@ def split_graph(model, split_edge_groups):
|
|||
'_recv' + str(cut_index)
|
||||
add_expand_type(model, new_receive_output_name, output_type)
|
||||
|
||||
new_wait_output_name = output_edge_name + '_wait' + str(cut_index)
|
||||
add_expand_type(model, new_wait_output_name, output_type)
|
||||
|
||||
# the order of data flow is: node-output -> record -> send -> recv -> wait -> node-input
|
||||
new_record.input.extend([output_edge_name])
|
||||
new_record.output.extend([new_send_input_name])
|
||||
|
||||
new_send.input.extend([new_send_input_name])
|
||||
new_send.input.extend([output_edge_name])
|
||||
new_receive.output.extend([new_receive_output_name])
|
||||
|
||||
new_wait.input.extend([new_receive_output_name])
|
||||
new_wait.output.extend([new_wait_output_name])
|
||||
|
||||
for output_node in output_nodes:
|
||||
for i in range(len(output_node.input)):
|
||||
for edgeId in edgeIds:
|
||||
if output_node.input[i] == edgeId:
|
||||
output_node.input[i] = new_wait_output_name
|
||||
output_node.input[i] = new_receive_output_name
|
||||
|
||||
new_send_nodes.append(new_send)
|
||||
new_recv_nodes.append(new_receive)
|
||||
|
|
@ -236,9 +202,50 @@ def add_identity(model, cuttingEdge, newEdgeIdName):
|
|||
if output_nodes[i].input[j] == edgeId:
|
||||
output_nodes[i].input[j] = newEdgeIdName
|
||||
|
||||
return newEdgeIdName
|
||||
return new_identity
|
||||
|
||||
|
||||
def insert_identity(model, all_cut_inputs):
|
||||
count = 0
|
||||
updated_edges = {}
|
||||
new_added_identity = []
|
||||
split_edge_groups = []
|
||||
need_shape_inference = False
|
||||
# Sweep the cut edge to see if there are edges feeding into nodes from two sub-graphs. If so,
|
||||
# insert identity node after those edges with a new ID to distinguish the rest.
|
||||
for cut_input in all_cut_inputs:
|
||||
split_edges = []
|
||||
for i in cut_input:
|
||||
if i.consumingNodes:
|
||||
# if this edge has previously been modified, update its edgeId before inserting new identity
|
||||
if i.edgeId in updated_edges:
|
||||
i.edgeId = updated_edges[i.edgeId]
|
||||
|
||||
new_edge_name = 'identity_output_' + str(count)
|
||||
new_added_identity.append(
|
||||
add_identity(model, i, new_edge_name))
|
||||
count += 1
|
||||
split_edges.append(new_edge_name)
|
||||
updated_edges[i.edgeId] = new_edge_name
|
||||
need_shape_inference = True
|
||||
else:
|
||||
split_edges.append(i.edgeId)
|
||||
split_edge_groups.append(split_edges)
|
||||
return split_edge_groups, new_added_identity, need_shape_inference
|
||||
|
||||
# after the graph is split, remove the added identity node because identity op is not registered in gradient builder.
|
||||
|
||||
|
||||
def remove_identity(model, new_added_identity):
|
||||
for node in new_added_identity:
|
||||
assert node.op_type == 'Identity'
|
||||
output_nodes = [
|
||||
n for n in model.graph.node if node.output[0] in n.input]
|
||||
for output_node in output_nodes:
|
||||
for i in range(len(output_node.input)):
|
||||
if output_node.input[i] == node.output[0]:
|
||||
output_node.input[i] = node.input[0]
|
||||
|
||||
def find_all_connected_nodes(model, node):
|
||||
nodes0, inputs = find_all_input_nodes(model, node)
|
||||
nodes1, outputs = find_all_output_nodes(model, node)
|
||||
|
|
@ -251,15 +258,36 @@ def get_index(node_list, node):
|
|||
found = [i for i, n in enumerate(node_list) if n == node]
|
||||
return found[0] if found else None
|
||||
|
||||
def get_identity_index_for_deleting(node_list, node):
|
||||
for i, n in enumerate(node_list):
|
||||
# The node's input name has been changed during send/recv insertion,
|
||||
# but it is sufficient to just compare the type and outputs.
|
||||
if (n.op_type == 'Identity' and n.output == node.output):
|
||||
return i
|
||||
return None
|
||||
|
||||
# traverse the graph, group connected nodes and generate subgraph
|
||||
|
||||
|
||||
def generate_subgraph(model, start_nodes):
|
||||
def generate_subgraph(model, start_nodes, identity_node_list):
|
||||
subgraphs = []
|
||||
|
||||
main_graph = onnx.ModelProto()
|
||||
main_graph.CopyFrom(model)
|
||||
|
||||
# remove added identity node before copy to subgraph
|
||||
identity_node_index = []
|
||||
for n in identity_node_list:
|
||||
identity_node_index.append(get_identity_index_for_deleting(main_graph.graph.node, n))
|
||||
identity_node_index.sort(reverse=True)
|
||||
|
||||
for i in reversed(range(len(main_graph.graph.node))):
|
||||
try:
|
||||
if i in identity_node_index:
|
||||
del main_graph.graph.node[i]
|
||||
except:
|
||||
print("error deleting identity node", i)
|
||||
|
||||
all_visited_nodes = []
|
||||
model_count = len(start_nodes)
|
||||
for start in reversed(start_nodes):
|
||||
|
|
@ -362,29 +390,8 @@ def main():
|
|||
output_model_names = [os.path.splitext(input_model_name)[0] + '_' +
|
||||
str(i) + '.onnx' for i in range(stage_count)]
|
||||
|
||||
split_edge_groups = []
|
||||
count = 0
|
||||
updated_edges = {}
|
||||
need_shape_inference = False
|
||||
# Sweep the cut edge to see if there are edges feeding into nodes from two sub-graphs. If so,
|
||||
# insert identity node after those edges with a new ID to distinguish the rest.
|
||||
for cut_input in all_cut_inputs:
|
||||
split_edges = []
|
||||
for i in cut_input:
|
||||
if i.consumingNodes:
|
||||
# if this edge has previously been modified, update its edgeId before inserting new identity
|
||||
if i.edgeId in updated_edges:
|
||||
i.edgeId = updated_edges[i.edgeId]
|
||||
split_edge_groups, new_identity, need_shape_inference = insert_identity(model, all_cut_inputs)
|
||||
|
||||
new_edge_name = 'identity_output_' + str(count)
|
||||
add_identity(model, i, new_edge_name)
|
||||
count += 1
|
||||
split_edges.append(new_edge_name)
|
||||
updated_edges[i.edgeId] = new_edge_name
|
||||
need_shape_inference = True
|
||||
else:
|
||||
split_edges.append(i.edgeId)
|
||||
split_edge_groups.append(split_edges)
|
||||
|
||||
# new edge is being added, need to re-inference shape
|
||||
if need_shape_inference:
|
||||
|
|
@ -392,11 +399,13 @@ def main():
|
|||
|
||||
# after all need-to-be-cut edges identified, split the graph
|
||||
new_sends, new_receives = split_graph(model, split_edge_groups)
|
||||
sub_graphs = generate_subgraph(model, new_receives)
|
||||
remove_identity(model, new_identity)
|
||||
sub_graphs = generate_subgraph(model, new_receives, new_identity)
|
||||
|
||||
for i in range(stage_count):
|
||||
sub_graphs[i] = onnx.shape_inference.infer_shapes(sub_graphs[i])
|
||||
onnx.save(sub_graphs[i], output_model_names[i])
|
||||
print("save to file: ", output_model_names[i])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Reference in a new issue