mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-09 00:30:53 +00:00
bugfix for graph inputs and outputs.
This commit is contained in:
parent
b7564d0732
commit
e759da178d
2 changed files with 46 additions and 12 deletions
|
|
@ -53,6 +53,7 @@ Status ModuleGradientGraphBuilder::BuildAndSplit(std::istream& model_istream,
|
|||
ORT_RETURN_IF_ERROR(Model::Load(model_proto, model_, nullptr, *logger_));
|
||||
ORT_RETURN_IF_ERROR(model_->MainGraph().Resolve());
|
||||
|
||||
// Register and apply transformers for pre-training.
|
||||
const TrainingSession::TrainingConfiguration::GraphTransformerConfiguration graph_transformer_config{};
|
||||
GraphTransformerManager graph_transformation_mgr{2};
|
||||
std::unique_ptr<CPUExecutionProvider> cpu_execution_provider =
|
||||
|
|
@ -73,14 +74,14 @@ Status ModuleGradientGraphBuilder::BuildAndSplit(std::istream& model_istream,
|
|||
}
|
||||
}
|
||||
|
||||
// apply transformers
|
||||
Graph& graph = model_->MainGraph();
|
||||
for (int i = static_cast<int>(TransformerLevel::Level1); i <= static_cast<int>(TransformerLevel::MaxLevel); i++) {
|
||||
ORT_RETURN_IF_ERROR(graph_transformation_mgr.ApplyTransformers(graph, static_cast<TransformerLevel>(i), *logger_));
|
||||
}
|
||||
|
||||
// TODO: mixed precision transformer.
|
||||
|
||||
|
||||
// Build gradient graph.
|
||||
GradientGraphConfiguration gradient_graph_config{};
|
||||
gradient_graph_config.use_invertible_layernorm_grad = config.use_invertible_layernorm_grad;
|
||||
gradient_graph_config.set_gradients_as_graph_outputs = config.set_gradients_as_graph_outputs;
|
||||
|
|
@ -92,7 +93,7 @@ Status ModuleGradientGraphBuilder::BuildAndSplit(std::istream& model_istream,
|
|||
*logger_);
|
||||
ORT_RETURN_IF_ERROR(grad_graph_builder.Build());
|
||||
|
||||
// Fix inputs/outputs related to gradient.
|
||||
// Fix inputs/outputs related to gradients.
|
||||
Graph& gradient_graph = model_->MainGraph();
|
||||
GraphViewer gradient_graph_viewer(gradient_graph);
|
||||
const auto& node_topology_list = gradient_graph_viewer.GetNodesInTopologicalOrder();
|
||||
|
|
@ -104,20 +105,20 @@ Status ModuleGradientGraphBuilder::BuildAndSplit(std::istream& model_istream,
|
|||
}
|
||||
|
||||
const std::vector<const NodeArg*>& gradient_graph_inputs = gradient_graph.GetInputsIncludingInitializers();
|
||||
std::vector<std::string> graph_input_names;
|
||||
std::vector<std::string> graph_input_names;
|
||||
std::vector<const NodeArg*> input_args;
|
||||
for (auto& node_arg : gradient_graph_inputs) {
|
||||
input_args.push_back(node_arg);
|
||||
graph_input_names.push_back(node_arg->Name());
|
||||
}
|
||||
|
||||
// Add the entry points of gradients (normally loss_gard) to the graph inputs.
|
||||
for (const auto& output_name : config.output_names) {
|
||||
std::string output_gradient_name = output_name + "_grad";
|
||||
if (input_names.find(output_gradient_name) != input_names.end()) {
|
||||
if (input_names.find(output_gradient_name) != input_names.end() &&
|
||||
output_names.find(output_gradient_name) == output_names.end()) {
|
||||
NodeArg* output_gradient_node_arg = gradient_graph.GetNodeArg(output_gradient_name);
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
output_gradient_node_arg->UpdateTypeAndShape(*gradient_graph.GetNodeArg(output_name), true, true, *logger_);
|
||||
#endif
|
||||
input_args.push_back(output_gradient_node_arg);
|
||||
}
|
||||
}
|
||||
|
|
@ -130,6 +131,7 @@ Status ModuleGradientGraphBuilder::BuildAndSplit(std::istream& model_istream,
|
|||
output_args.push_back(node_arg);
|
||||
}
|
||||
|
||||
// Add weight gradients to graph outputs.
|
||||
for (const auto& weight_name : config.weight_names_to_train) {
|
||||
std::string weight_gradient_name = weight_name + "_grad";
|
||||
if (output_names.find(weight_gradient_name) != output_names.end()) {
|
||||
|
|
@ -137,6 +139,7 @@ Status ModuleGradientGraphBuilder::BuildAndSplit(std::istream& model_istream,
|
|||
}
|
||||
}
|
||||
|
||||
// Add input gradients to graph outputs if it's calculated.
|
||||
for (const auto& graph_input_name : graph_input_names) {
|
||||
std::string input_gradient_name = graph_input_name + "_grad";
|
||||
if (output_names.find(input_gradient_name) != output_names.end()) {
|
||||
|
|
@ -145,14 +148,18 @@ Status ModuleGradientGraphBuilder::BuildAndSplit(std::istream& model_istream,
|
|||
}
|
||||
|
||||
gradient_graph.SetOutputs(output_args);
|
||||
|
||||
|
||||
gradient_graph.Resolve();
|
||||
|
||||
// Create two copies of gradient model for forward and backward models respectively.
|
||||
auto gradient_model_proto = model_->ToProto();
|
||||
ORT_RETURN_IF_ERROR(Model::Load(gradient_model_proto, forward_model_, nullptr, *logger_));
|
||||
ORT_RETURN_IF_ERROR(Model::Load(gradient_model_proto, backward_model_, nullptr, *logger_));
|
||||
|
||||
// Split the graph in the copies of gradient model.
|
||||
ORT_RETURN_IF_ERROR(Split(config));
|
||||
|
||||
// Serialize the models as output to frontend.
|
||||
std::string gradient_model_str;
|
||||
if (!model_->ToProto().SerializeToString(&gradient_model_str)) {
|
||||
return Status(ONNXRUNTIME, FAIL, "Fail to serialize gradient model to string.");
|
||||
|
|
@ -214,6 +221,7 @@ Status ModuleGradientGraphBuilder::Split(const ModuleGradientGraphBuilderConfigu
|
|||
}
|
||||
}
|
||||
|
||||
// Add weights to forward graph inputs.
|
||||
for (const auto& weight_name : config.weight_names_to_train) {
|
||||
forward_input_args.push_back(forward_graph.GetNodeArg(weight_name));
|
||||
}
|
||||
|
|
@ -225,12 +233,17 @@ Status ModuleGradientGraphBuilder::Split(const ModuleGradientGraphBuilderConfigu
|
|||
forward_output_args.push_back(forward_graph.GetNodeArg(output_name));
|
||||
}
|
||||
|
||||
// Add intermediate args to forward graph outputs.
|
||||
for (const auto& intermediate_arg_name : intermediate_arg_names) {
|
||||
forward_output_args.push_back(forward_graph.GetNodeArg(intermediate_arg_name));
|
||||
// Ignore those duplicates.
|
||||
if (config.output_names.find(intermediate_arg_name) == config.output_names.end()) {
|
||||
forward_output_args.push_back(forward_graph.GetNodeArg(intermediate_arg_name));
|
||||
}
|
||||
}
|
||||
|
||||
forward_graph.SetOutputs(forward_output_args);
|
||||
|
||||
// Resolve the forward graph, keep the weight initializers for now.
|
||||
Graph::ResolveOptions options;
|
||||
options.initializer_names_to_preserve = &config.weight_names_to_train;
|
||||
forward_graph.Resolve(options);
|
||||
|
|
@ -258,6 +271,7 @@ Status ModuleGradientGraphBuilder::Split(const ModuleGradientGraphBuilderConfigu
|
|||
}
|
||||
}
|
||||
|
||||
// Add weight args to backward graph inputs if any node uses them.
|
||||
for (const auto& weight_name : config.weight_names_to_train) {
|
||||
// Weights will be inputs for backward graph.
|
||||
if (backward_input_names.find(weight_name) != backward_input_names.end()) {
|
||||
|
|
@ -266,11 +280,10 @@ Status ModuleGradientGraphBuilder::Split(const ModuleGradientGraphBuilderConfigu
|
|||
}
|
||||
}
|
||||
|
||||
// Add intermediate args to backward graph inputs.
|
||||
for (const auto& intermediate_arg_name : intermediate_arg_names) {
|
||||
NodeArg* intermediate_node_arg = backward_graph.GetNodeArg(intermediate_arg_name);
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
intermediate_node_arg->UpdateTypeAndShape(*forward_graph.GetNodeArg(intermediate_arg_name), true, true, *logger_);
|
||||
#endif
|
||||
backward_input_args.push_back(intermediate_node_arg);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -136,9 +136,10 @@ onnx.save(models[0], 'minst_gradient_graph.onnx')
|
|||
onnx.save(models[1], 'mnist_forward.onnx')
|
||||
onnx.save(models[2], 'mnist_backward.onnx')
|
||||
|
||||
|
||||
"""
|
||||
#BERT
|
||||
original_model = onnx.load('bert-tiny.onnx')
|
||||
original_model = onnx.load('BertForSequenceClassification_full_training.onnx')
|
||||
config = C.ModuleGradientGraphBuilderConfiguration()
|
||||
weight_names_to_train = set()
|
||||
for initializer in original_model.graph.initializer:
|
||||
|
|
@ -149,6 +150,26 @@ for output in original_model.graph.output:
|
|||
output_names.add(output.name)
|
||||
config.output_names = output_names
|
||||
|
||||
models = [onnx.load_model_from_string(model_as_string) for model_as_string in C.ModuleGradientGraphBuilder().build_and_split(original_model.SerializeToString(), config)]
|
||||
onnx.save(models[0], 'bert_gradient_graph.onnx')
|
||||
onnx.save(models[1], 'bert_forward.onnx')
|
||||
onnx.save(models[2], 'bert_backward.onnx')
|
||||
|
||||
|
||||
#BERT with loss
|
||||
original_model = onnx.load('bert-tiny-loss.onnx')
|
||||
config = C.ModuleGradientGraphBuilderConfiguration()
|
||||
weight_names_to_train = set()
|
||||
for initializer in original_model.graph.initializer:
|
||||
if initializer.name.startswith('bert.') or initializer.name.startswith('cls.'):
|
||||
weight_names_to_train.add(initializer.name)
|
||||
config.weight_names_to_train = weight_names_to_train
|
||||
output_names = set()
|
||||
output_names.add('total_loss')
|
||||
#for output in original_model.graph.output:
|
||||
# output_names.add(output.name)
|
||||
config.output_names = output_names
|
||||
|
||||
models = [onnx.load_model_from_string(model_as_string) for model_as_string in C.ModuleGradientGraphBuilder().build_and_split(original_model.SerializeToString(), config)]
|
||||
onnx.save(models[0], 'bert_gradient_graph.onnx')
|
||||
onnx.save(models[1], 'bert_forward.onnx')
|
||||
|
|
|
|||
Loading…
Reference in a new issue