diff --git a/orttraining/orttraining/core/framework/module_gradient_graph_builder.cc b/orttraining/orttraining/core/framework/module_gradient_graph_builder.cc index d2bcc70ac0..19d6a92e72 100644 --- a/orttraining/orttraining/core/framework/module_gradient_graph_builder.cc +++ b/orttraining/orttraining/core/framework/module_gradient_graph_builder.cc @@ -53,6 +53,7 @@ Status ModuleGradientGraphBuilder::BuildAndSplit(std::istream& model_istream, ORT_RETURN_IF_ERROR(Model::Load(model_proto, model_, nullptr, *logger_)); ORT_RETURN_IF_ERROR(model_->MainGraph().Resolve()); + // Register and apply transformers for pre-training. const TrainingSession::TrainingConfiguration::GraphTransformerConfiguration graph_transformer_config{}; GraphTransformerManager graph_transformation_mgr{2}; std::unique_ptr cpu_execution_provider = @@ -73,14 +74,14 @@ Status ModuleGradientGraphBuilder::BuildAndSplit(std::istream& model_istream, } } - // apply transformers Graph& graph = model_->MainGraph(); for (int i = static_cast(TransformerLevel::Level1); i <= static_cast(TransformerLevel::MaxLevel); i++) { ORT_RETURN_IF_ERROR(graph_transformation_mgr.ApplyTransformers(graph, static_cast(i), *logger_)); } // TODO: mixed precision transformer. - + + // Build gradient graph. GradientGraphConfiguration gradient_graph_config{}; gradient_graph_config.use_invertible_layernorm_grad = config.use_invertible_layernorm_grad; gradient_graph_config.set_gradients_as_graph_outputs = config.set_gradients_as_graph_outputs; @@ -92,7 +93,7 @@ Status ModuleGradientGraphBuilder::BuildAndSplit(std::istream& model_istream, *logger_); ORT_RETURN_IF_ERROR(grad_graph_builder.Build()); - // Fix inputs/outputs related to gradient. + // Fix inputs/outputs related to gradients. 
Graph& gradient_graph = model_->MainGraph(); GraphViewer gradient_graph_viewer(gradient_graph); const auto& node_topology_list = gradient_graph_viewer.GetNodesInTopologicalOrder(); @@ -104,20 +105,20 @@ Status ModuleGradientGraphBuilder::BuildAndSplit(std::istream& model_istream, } const std::vector& gradient_graph_inputs = gradient_graph.GetInputsIncludingInitializers(); - std::vector graph_input_names; + std::vector graph_input_names; std::vector input_args; for (auto& node_arg : gradient_graph_inputs) { input_args.push_back(node_arg); graph_input_names.push_back(node_arg->Name()); } + // Add the entry points of gradients (normally loss_grad) to the graph inputs. for (const auto& output_name : config.output_names) { std::string output_gradient_name = output_name + "_grad"; - if (input_names.find(output_gradient_name) != input_names.end()) { + if (input_names.find(output_gradient_name) != input_names.end() && + output_names.find(output_gradient_name) == output_names.end()) { NodeArg* output_gradient_node_arg = gradient_graph.GetNodeArg(output_gradient_name); -#if !defined(ORT_MINIMAL_BUILD) output_gradient_node_arg->UpdateTypeAndShape(*gradient_graph.GetNodeArg(output_name), true, true, *logger_); -#endif input_args.push_back(output_gradient_node_arg); } } @@ -130,6 +131,7 @@ Status ModuleGradientGraphBuilder::BuildAndSplit(std::istream& model_istream, output_args.push_back(node_arg); } + // Add weight gradients to graph outputs. for (const auto& weight_name : config.weight_names_to_train) { std::string weight_gradient_name = weight_name + "_grad"; if (output_names.find(weight_gradient_name) != output_names.end()) { @@ -137,6 +139,7 @@ Status ModuleGradientGraphBuilder::BuildAndSplit(std::istream& model_istream, } } + // Add input gradients to graph outputs if they are calculated. 
for (const auto& graph_input_name : graph_input_names) { std::string input_gradient_name = graph_input_name + "_grad"; if (output_names.find(input_gradient_name) != output_names.end()) { @@ -145,14 +148,18 @@ Status ModuleGradientGraphBuilder::BuildAndSplit(std::istream& model_istream, } gradient_graph.SetOutputs(output_args); - + gradient_graph.Resolve(); + // Create two copies of gradient model for forward and backward models respectively. auto gradient_model_proto = model_->ToProto(); ORT_RETURN_IF_ERROR(Model::Load(gradient_model_proto, forward_model_, nullptr, *logger_)); ORT_RETURN_IF_ERROR(Model::Load(gradient_model_proto, backward_model_, nullptr, *logger_)); + + // Split the graph in the copies of gradient model. ORT_RETURN_IF_ERROR(Split(config)); + // Serialize the models as output to frontend. std::string gradient_model_str; if (!model_->ToProto().SerializeToString(&gradient_model_str)) { return Status(ONNXRUNTIME, FAIL, "Fail to serialize gradient model to string."); @@ -214,6 +221,7 @@ Status ModuleGradientGraphBuilder::Split(const ModuleGradientGraphBuilderConfigu } } + // Add weights to forward graph inputs. for (const auto& weight_name : config.weight_names_to_train) { forward_input_args.push_back(forward_graph.GetNodeArg(weight_name)); } @@ -225,12 +233,17 @@ Status ModuleGradientGraphBuilder::Split(const ModuleGradientGraphBuilderConfigu forward_output_args.push_back(forward_graph.GetNodeArg(output_name)); } + // Add intermediate args to forward graph outputs. for (const auto& intermediate_arg_name : intermediate_arg_names) { - forward_output_args.push_back(forward_graph.GetNodeArg(intermediate_arg_name)); + // Ignore those duplicates. + if (config.output_names.find(intermediate_arg_name) == config.output_names.end()) { + forward_output_args.push_back(forward_graph.GetNodeArg(intermediate_arg_name)); + } } forward_graph.SetOutputs(forward_output_args); + // Resolve the forward graph, keep the weight initializers for now. 
Graph::ResolveOptions options; options.initializer_names_to_preserve = &config.weight_names_to_train; forward_graph.Resolve(options); @@ -258,6 +271,7 @@ Status ModuleGradientGraphBuilder::Split(const ModuleGradientGraphBuilderConfigu } } + // Add weight args to backward graph inputs if any node uses them. for (const auto& weight_name : config.weight_names_to_train) { // Weights will be inputs for backward graph. if (backward_input_names.find(weight_name) != backward_input_names.end()) { @@ -266,11 +280,10 @@ Status ModuleGradientGraphBuilder::Split(const ModuleGradientGraphBuilderConfigu } } + // Add intermediate args to backward graph inputs. for (const auto& intermediate_arg_name : intermediate_arg_names) { NodeArg* intermediate_node_arg = backward_graph.GetNodeArg(intermediate_arg_name); -#if !defined(ORT_MINIMAL_BUILD) intermediate_node_arg->UpdateTypeAndShape(*forward_graph.GetNodeArg(intermediate_arg_name), true, true, *logger_); -#endif backward_input_args.push_back(intermediate_node_arg); } diff --git a/samples/python/mnist/graph_spliter.py b/samples/python/mnist/graph_spliter.py index 2f90bcde10..90a76126d3 100644 --- a/samples/python/mnist/graph_spliter.py +++ b/samples/python/mnist/graph_spliter.py @@ -136,9 +136,10 @@ onnx.save(models[0], 'minst_gradient_graph.onnx') onnx.save(models[1], 'mnist_forward.onnx') onnx.save(models[2], 'mnist_backward.onnx') + """ #BERT -original_model = onnx.load('bert-tiny.onnx') +original_model = onnx.load('BertForSequenceClassification_full_training.onnx') config = C.ModuleGradientGraphBuilderConfiguration() weight_names_to_train = set() for initializer in original_model.graph.initializer: @@ -149,6 +150,26 @@ for output in original_model.graph.output: output_names.add(output.name) config.output_names = output_names +models = [onnx.load_model_from_string(model_as_string) for model_as_string in C.ModuleGradientGraphBuilder().build_and_split(original_model.SerializeToString(), config)] +onnx.save(models[0], 
'bert_gradient_graph.onnx') +onnx.save(models[1], 'bert_forward.onnx') +onnx.save(models[2], 'bert_backward.onnx') + + +#BERT with loss +original_model = onnx.load('bert-tiny-loss.onnx') +config = C.ModuleGradientGraphBuilderConfiguration() +weight_names_to_train = set() +for initializer in original_model.graph.initializer: + if initializer.name.startswith('bert.') or initializer.name.startswith('cls.'): + weight_names_to_train.add(initializer.name) +config.weight_names_to_train = weight_names_to_train +output_names = set() +output_names.add('total_loss') +#for output in original_model.graph.output: +# output_names.add(output.name) +config.output_names = output_names + models = [onnx.load_model_from_string(model_as_string) for model_as_string in C.ModuleGradientGraphBuilder().build_and_split(original_model.SerializeToString(), config)] onnx.save(models[0], 'bert_gradient_graph.onnx') onnx.save(models[1], 'bert_forward.onnx')