bugfix for graph inputs and outputs.

This commit is contained in:
Vincent Wang 2020-11-09 07:07:11 +00:00 committed by Thiago Crepaldi
parent b7564d0732
commit e759da178d
2 changed files with 46 additions and 12 deletions

View file

@ -53,6 +53,7 @@ Status ModuleGradientGraphBuilder::BuildAndSplit(std::istream& model_istream,
ORT_RETURN_IF_ERROR(Model::Load(model_proto, model_, nullptr, *logger_));
ORT_RETURN_IF_ERROR(model_->MainGraph().Resolve());
// Register and apply transformers for pre-training.
const TrainingSession::TrainingConfiguration::GraphTransformerConfiguration graph_transformer_config{};
GraphTransformerManager graph_transformation_mgr{2};
std::unique_ptr<CPUExecutionProvider> cpu_execution_provider =
@ -73,14 +74,14 @@ Status ModuleGradientGraphBuilder::BuildAndSplit(std::istream& model_istream,
}
}
// apply transformers
Graph& graph = model_->MainGraph();
for (int i = static_cast<int>(TransformerLevel::Level1); i <= static_cast<int>(TransformerLevel::MaxLevel); i++) {
ORT_RETURN_IF_ERROR(graph_transformation_mgr.ApplyTransformers(graph, static_cast<TransformerLevel>(i), *logger_));
}
// TODO: mixed precision transformer.
// Build gradient graph.
GradientGraphConfiguration gradient_graph_config{};
gradient_graph_config.use_invertible_layernorm_grad = config.use_invertible_layernorm_grad;
gradient_graph_config.set_gradients_as_graph_outputs = config.set_gradients_as_graph_outputs;
@ -92,7 +93,7 @@ Status ModuleGradientGraphBuilder::BuildAndSplit(std::istream& model_istream,
*logger_);
ORT_RETURN_IF_ERROR(grad_graph_builder.Build());
// Fix inputs/outputs related to gradient.
// Fix inputs/outputs related to gradients.
Graph& gradient_graph = model_->MainGraph();
GraphViewer gradient_graph_viewer(gradient_graph);
const auto& node_topology_list = gradient_graph_viewer.GetNodesInTopologicalOrder();
@ -104,20 +105,20 @@ Status ModuleGradientGraphBuilder::BuildAndSplit(std::istream& model_istream,
}
const std::vector<const NodeArg*>& gradient_graph_inputs = gradient_graph.GetInputsIncludingInitializers();
std::vector<std::string> graph_input_names;
std::vector<std::string> graph_input_names;
std::vector<const NodeArg*> input_args;
for (auto& node_arg : gradient_graph_inputs) {
input_args.push_back(node_arg);
graph_input_names.push_back(node_arg->Name());
}
// Add the entry points of gradients (normally loss_gard) to the graph inputs.
for (const auto& output_name : config.output_names) {
std::string output_gradient_name = output_name + "_grad";
if (input_names.find(output_gradient_name) != input_names.end()) {
if (input_names.find(output_gradient_name) != input_names.end() &&
output_names.find(output_gradient_name) == output_names.end()) {
NodeArg* output_gradient_node_arg = gradient_graph.GetNodeArg(output_gradient_name);
#if !defined(ORT_MINIMAL_BUILD)
output_gradient_node_arg->UpdateTypeAndShape(*gradient_graph.GetNodeArg(output_name), true, true, *logger_);
#endif
input_args.push_back(output_gradient_node_arg);
}
}
@ -130,6 +131,7 @@ Status ModuleGradientGraphBuilder::BuildAndSplit(std::istream& model_istream,
output_args.push_back(node_arg);
}
// Add weight gradients to graph outputs.
for (const auto& weight_name : config.weight_names_to_train) {
std::string weight_gradient_name = weight_name + "_grad";
if (output_names.find(weight_gradient_name) != output_names.end()) {
@ -137,6 +139,7 @@ Status ModuleGradientGraphBuilder::BuildAndSplit(std::istream& model_istream,
}
}
// Add input gradients to graph outputs if it's calculated.
for (const auto& graph_input_name : graph_input_names) {
std::string input_gradient_name = graph_input_name + "_grad";
if (output_names.find(input_gradient_name) != output_names.end()) {
@ -145,14 +148,18 @@ Status ModuleGradientGraphBuilder::BuildAndSplit(std::istream& model_istream,
}
gradient_graph.SetOutputs(output_args);
gradient_graph.Resolve();
// Create two copies of gradient model for forward and backward models respectively.
auto gradient_model_proto = model_->ToProto();
ORT_RETURN_IF_ERROR(Model::Load(gradient_model_proto, forward_model_, nullptr, *logger_));
ORT_RETURN_IF_ERROR(Model::Load(gradient_model_proto, backward_model_, nullptr, *logger_));
// Split the graph in the copies of gradient model.
ORT_RETURN_IF_ERROR(Split(config));
// Serialize the models as output to frontend.
std::string gradient_model_str;
if (!model_->ToProto().SerializeToString(&gradient_model_str)) {
return Status(ONNXRUNTIME, FAIL, "Fail to serialize gradient model to string.");
@ -214,6 +221,7 @@ Status ModuleGradientGraphBuilder::Split(const ModuleGradientGraphBuilderConfigu
}
}
// Add weights to forward graph inputs.
for (const auto& weight_name : config.weight_names_to_train) {
forward_input_args.push_back(forward_graph.GetNodeArg(weight_name));
}
@ -225,12 +233,17 @@ Status ModuleGradientGraphBuilder::Split(const ModuleGradientGraphBuilderConfigu
forward_output_args.push_back(forward_graph.GetNodeArg(output_name));
}
// Add intermediate args to forward graph outputs.
for (const auto& intermediate_arg_name : intermediate_arg_names) {
forward_output_args.push_back(forward_graph.GetNodeArg(intermediate_arg_name));
// Ignore those duplicates.
if (config.output_names.find(intermediate_arg_name) == config.output_names.end()) {
forward_output_args.push_back(forward_graph.GetNodeArg(intermediate_arg_name));
}
}
forward_graph.SetOutputs(forward_output_args);
// Resolve the forward graph, keep the weight initializers for now.
Graph::ResolveOptions options;
options.initializer_names_to_preserve = &config.weight_names_to_train;
forward_graph.Resolve(options);
@ -258,6 +271,7 @@ Status ModuleGradientGraphBuilder::Split(const ModuleGradientGraphBuilderConfigu
}
}
// Add weight args to backward graph inputs if any node uses them.
for (const auto& weight_name : config.weight_names_to_train) {
// Weights will be inputs for backward graph.
if (backward_input_names.find(weight_name) != backward_input_names.end()) {
@ -266,11 +280,10 @@ Status ModuleGradientGraphBuilder::Split(const ModuleGradientGraphBuilderConfigu
}
}
// Add intermediate args to backward graph inputs.
for (const auto& intermediate_arg_name : intermediate_arg_names) {
NodeArg* intermediate_node_arg = backward_graph.GetNodeArg(intermediate_arg_name);
#if !defined(ORT_MINIMAL_BUILD)
intermediate_node_arg->UpdateTypeAndShape(*forward_graph.GetNodeArg(intermediate_arg_name), true, true, *logger_);
#endif
backward_input_args.push_back(intermediate_node_arg);
}

View file

@ -136,9 +136,10 @@ onnx.save(models[0], 'minst_gradient_graph.onnx')
onnx.save(models[1], 'mnist_forward.onnx')
onnx.save(models[2], 'mnist_backward.onnx')
"""
#BERT
original_model = onnx.load('bert-tiny.onnx')
original_model = onnx.load('BertForSequenceClassification_full_training.onnx')
config = C.ModuleGradientGraphBuilderConfiguration()
weight_names_to_train = set()
for initializer in original_model.graph.initializer:
@ -149,6 +150,26 @@ for output in original_model.graph.output:
output_names.add(output.name)
config.output_names = output_names
models = [onnx.load_model_from_string(model_as_string) for model_as_string in C.ModuleGradientGraphBuilder().build_and_split(original_model.SerializeToString(), config)]
onnx.save(models[0], 'bert_gradient_graph.onnx')
onnx.save(models[1], 'bert_forward.onnx')
onnx.save(models[2], 'bert_backward.onnx')
#BERT with loss
original_model = onnx.load('bert-tiny-loss.onnx')
config = C.ModuleGradientGraphBuilderConfiguration()
weight_names_to_train = set()
for initializer in original_model.graph.initializer:
if initializer.name.startswith('bert.') or initializer.name.startswith('cls.'):
weight_names_to_train.add(initializer.name)
config.weight_names_to_train = weight_names_to_train
output_names = set()
output_names.add('total_loss')
#for output in original_model.graph.output:
# output_names.add(output.name)
config.output_names = output_names
models = [onnx.load_model_from_string(model_as_string) for model_as_string in C.ModuleGradientGraphBuilder().build_and_split(original_model.SerializeToString(), config)]
onnx.save(models[0], 'bert_gradient_graph.onnx')
onnx.save(models[1], 'bert_forward.onnx')