diff --git a/orttraining/orttraining/core/framework/module_gradient_graph_builder.cc b/orttraining/orttraining/core/framework/module_gradient_graph_builder.cc
index e23f55a329..f6123cedc0 100644
--- a/orttraining/orttraining/core/framework/module_gradient_graph_builder.cc
+++ b/orttraining/orttraining/core/framework/module_gradient_graph_builder.cc
@@ -45,23 +45,39 @@ void FilterInitializers(Graph& graph, const std::unordered_set<std::string>& input_names) {
 }
 
 Status ModuleGradientGraphBuilder::BuildAndSplit(std::istream& model_istream,
-                                                 const ModuleGradientGraphBuilderConfiguration& config,
-                                                 std::vector<std::string>& models_as_string) {
+                                                 const ModuleGradientGraphBuilderConfiguration& config) {
   logger_ = &logging::LoggingManager::DefaultLogger();  // use default logger for now.
   ONNX_NAMESPACE::ModelProto model_proto;
   ORT_RETURN_IF_ERROR(Model::Load(model_istream, &model_proto));
   ORT_RETURN_IF_ERROR(Model::Load(model_proto, model_, nullptr, *logger_));
   ORT_RETURN_IF_ERROR(model_->MainGraph().Resolve());
 
+  // Handle original model inputs, outputs and trainable initializers.
+  const std::vector<const NodeArg*>& graph_inputs = model_->MainGraph().GetInputsIncludingInitializers();
+  for (auto& node_arg : graph_inputs) {
+    split_graphs_info_.user_input_names.emplace_back(node_arg->Name());
+  }
+
+  const std::vector<const NodeArg*>& graph_outputs = model_->MainGraph().GetOutputs();
+  for (auto& node_arg : graph_outputs) {
+    split_graphs_info_.user_output_names.emplace_back(node_arg->Name());
+  }
+
+  split_graphs_info_.initializer_names_to_train.assign(config.initializer_names_to_train.begin(),
+                                                       config.initializer_names_to_train.end());
+
   // Register and apply transformers for pre-training.
   const TrainingSession::TrainingConfiguration::GraphTransformerConfiguration graph_transformer_config{};
   GraphTransformerManager graph_transformation_mgr{2};
   std::unique_ptr<CPUExecutionProvider> cpu_execution_provider =
       onnxruntime::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
 
+  std::unordered_set<std::string> x_node_arg_names;
+  std::set_union(config.initializer_names_to_train.begin(), config.initializer_names_to_train.end(),
+                 config.input_names_require_grad.begin(), config.input_names_require_grad.end(),
+                 std::inserter(x_node_arg_names, x_node_arg_names.begin()));
   auto add_transformers = [&](TransformerLevel level) {
     auto transformers_to_register = transformer_utils::GeneratePreTrainingTransformers(
-        level, config.weight_names_to_train, graph_transformer_config, *cpu_execution_provider, {});
+        level, x_node_arg_names, graph_transformer_config, *cpu_execution_provider, {});
     for (auto& entry : transformers_to_register) {
       graph_transformation_mgr.Register(std::move(entry), level);
     }
@@ -85,12 +101,9 @@ Status ModuleGradientGraphBuilder::BuildAndSplit(std::istream& model_istream,
   GradientGraphConfiguration gradient_graph_config{};
   gradient_graph_config.use_invertible_layernorm_grad = config.use_invertible_layernorm_grad;
   gradient_graph_config.set_gradients_as_graph_outputs = config.set_gradients_as_graph_outputs;
-  std::unordered_set<std::string> x_node_arg_names;
-  std::set_union(config.weight_names_to_train.begin(), config.weight_names_to_train.end(),
-                 config.input_names_require_grad.begin(), config.input_names_require_grad.end(),
-                 std::inserter(x_node_arg_names, x_node_arg_names.begin()));
+  std::unordered_set<std::string> y_node_arg_names(split_graphs_info_.user_output_names.begin(),
+                                                   split_graphs_info_.user_output_names.end());
   GradientGraphBuilder grad_graph_builder(&model_->MainGraph(),
-                                          config.output_names,
+                                          y_node_arg_names,
                                           x_node_arg_names,
                                           "",  // not support loss name for now.
                                          gradient_graph_config,
@@ -108,44 +121,38 @@ Status ModuleGradientGraphBuilder::BuildAndSplit(std::istream& model_istream,
     GetInputAndOutputNames(node, input_names, output_names);
   }
 
-  const std::vector<const NodeArg*>& gradient_graph_inputs = gradient_graph.GetInputsIncludingInitializers();
-  std::vector<std::string> graph_input_names;
   std::vector<const NodeArg*> input_args;
-  for (auto& node_arg : gradient_graph_inputs) {
-    input_args.push_back(node_arg);
-    graph_input_names.push_back(node_arg->Name());
-  }
-
-  const std::vector<const NodeArg*>& gradient_graph_outputs = gradient_graph.GetOutputs();
-  std::vector<std::string> graph_output_names;
-  std::vector<const NodeArg*> output_args;
-  for (auto& node_arg : gradient_graph_outputs) {
-    output_args.push_back(node_arg);
-    graph_output_names.push_back(node_arg->Name());
+  for (auto& input_name : split_graphs_info_.user_input_names) {
+    input_args.emplace_back(gradient_graph.GetNodeArg(input_name));
   }
 
   // Add the entry points of gradients (normally loss_grad) to the graph inputs, using the order of graph outputs.
-  for (const auto& output_name : graph_output_names) {
-    if (config.output_names.find(output_name) == config.output_names.end()) {
-      continue;
-    }
-
+  for (const auto& output_name : split_graphs_info_.user_output_names) {
     std::string output_gradient_name = output_name + "_grad";
-    if (input_names.find(output_gradient_name) != input_names.end() &&
-        output_names.find(output_gradient_name) == output_names.end()) {
-      NodeArg* output_gradient_node_arg = gradient_graph.GetNodeArg(output_gradient_name);
-      output_gradient_node_arg->UpdateTypeAndShape(*gradient_graph.GetNodeArg(output_name), true, true, *logger_);
-      input_args.push_back(output_gradient_node_arg);
+    if (input_names.find(output_gradient_name) != input_names.end()) {
+      split_graphs_info_.user_output_grad_names.emplace_back(output_gradient_name);
+      // Only add it to the graph inputs when it's not an output of a node.
+      if (output_names.find(output_gradient_name) == output_names.end()) {
+        split_graphs_info_.backward_output_grad_names.emplace_back(output_gradient_name);
+        NodeArg* output_gradient_node_arg = gradient_graph.GetNodeArg(output_gradient_name);
+        output_gradient_node_arg->UpdateTypeAndShape(*gradient_graph.GetNodeArg(output_name), true, true, *logger_);
+        input_args.emplace_back(output_gradient_node_arg);
+      }
     }
   }
 
   gradient_graph.SetInputs(input_args);
 
-  // Add weight gradients to graph outputs.
-  for (const auto& weight_name : config.weight_names_to_train) {
-    std::string weight_gradient_name = weight_name + "_grad";
-    if (output_names.find(weight_gradient_name) != output_names.end()) {
-      output_args.push_back(gradient_graph.GetNodeArg(weight_gradient_name));
+  std::vector<const NodeArg*> output_args;
+  for (auto& output_name : split_graphs_info_.user_output_names) {
+    output_args.emplace_back(gradient_graph.GetNodeArg(output_name));
+  }
+
+  // Add initializer gradients to graph outputs.
+  for (const auto& initializer_name : split_graphs_info_.initializer_names_to_train) {
+    std::string initializer_gradient_name = initializer_name + "_grad";
+    if (output_names.find(initializer_gradient_name) != output_names.end()) {
+      output_args.emplace_back(gradient_graph.GetNodeArg(initializer_gradient_name));
     }
   }
 
@@ -153,7 +160,7 @@ Status ModuleGradientGraphBuilder::BuildAndSplit(std::istream& model_istream,
   for (const auto& input_name : config.input_names_require_grad) {
     std::string input_gradient_name = input_name + "_grad";
     if (output_names.find(input_gradient_name) != output_names.end()) {
-      output_args.push_back(gradient_graph.GetNodeArg(input_gradient_name));
+      output_args.emplace_back(gradient_graph.GetNodeArg(input_gradient_name));
     }
   }
 
@@ -167,33 +174,33 @@ Status ModuleGradientGraphBuilder::BuildAndSplit(std::istream& model_istream,
   ORT_RETURN_IF_ERROR(Model::Load(gradient_model_proto, backward_model_, nullptr, *logger_));
 
   // Split the graph in the copies of gradient model.
-  ORT_RETURN_IF_ERROR(Split(config, graph_output_names));
-
-  // Serialize the models as output to frontend.
-  std::string gradient_model_str;
-  if (!model_->ToProto().SerializeToString(&gradient_model_str)) {
-    return Status(ONNXRUNTIME, FAIL, "Fail to serialize gradient model to string.");
-  }
-
-  std::string forward_model_str;
-  if (!forward_model_->ToProto().SerializeToString(&forward_model_str)) {
-    return Status(ONNXRUNTIME, FAIL, "Fail to serialize forward model to string.");
-  }
-
-  std::string backward_model_str;
-  if (!backward_model_->ToProto().SerializeToString(&backward_model_str)) {
-    return Status(ONNXRUNTIME, FAIL, "Fail to serialize backward model to string.");
-  }
-
-  models_as_string.push_back(gradient_model_str);
-  models_as_string.push_back(forward_model_str);
-  models_as_string.push_back(backward_model_str);
+  ORT_RETURN_IF_ERROR(Split());
 
   return Status::OK();
 }
 
-Status ModuleGradientGraphBuilder::Split(const ModuleGradientGraphBuilderConfiguration& config,
-                                         const std::vector<std::string>& graph_output_names) {
+std::string SerializeModel(const std::shared_ptr<Model>& model, const std::string& tag) {
+  std::string model_str;
+  if (!model->ToProto().SerializeToString(&model_str)) {
+    ORT_THROW("Fail to serialize ", tag, " model to string.");
+  }
+
+  return model_str;
+}
+
+std::string ModuleGradientGraphBuilder::GetGradientModel() const {
+  return SerializeModel(model_, "gradient");
+}
+
+std::string ModuleGradientGraphBuilder::GetForwardModel() const {
+  return SerializeModel(forward_model_, "forward");
+}
+
+std::string ModuleGradientGraphBuilder::GetBackwardModel() const {
+  return SerializeModel(backward_model_, "backward");
+}
+
+Status ModuleGradientGraphBuilder::Split() {
   // Get forward model, also collect some information for backward model generation.
   Graph& forward_graph = forward_model_->MainGraph();
   GraphViewer forward_graph_viewer(forward_graph);
@@ -207,7 +214,7 @@ Status ModuleGradientGraphBuilder::Split(const ModuleGradientGraphBuilderConfigu
     auto& node = *forward_graph.GetNode(node_index);
     // Currently we are using node description to distinguish the forward and backward nodes.
     if (node.Description() == "Backward pass") {
-      forward_nodes_to_remove.push_back(&node);
+      forward_nodes_to_remove.emplace_back(&node);
       GetInputAndOutputNames(node, backward_input_names, backward_output_names);
     } else {
       GetInputAndOutputNames(node, forward_input_names, forward_output_names);
@@ -224,39 +231,41 @@ Status ModuleGradientGraphBuilder::Split(const ModuleGradientGraphBuilderConfigu
   RemoveNodes(forward_graph, forward_nodes_to_remove);
   FilterInitializers(forward_graph, forward_input_names);
 
-  const std::vector<const NodeArg*>& forward_graph_inputs = forward_graph.GetInputsIncludingInitializers();
+  // All user inputs should also be part of the forward graph inputs.
   std::vector<const NodeArg*> forward_input_args;
-  for (const NodeArg* node_arg : forward_graph_inputs) {
-    if (forward_input_names.find(node_arg->Name()) != forward_input_names.end()) {
-      forward_input_args.push_back(node_arg);
-    }
+  for (const auto& input_name : split_graphs_info_.user_input_names) {
+    forward_input_args.emplace_back(forward_graph.GetNodeArg(input_name));
   }
 
-  // Add weights to forward graph inputs.
-  for (const auto& weight_name : config.weight_names_to_train) {
-    forward_input_args.push_back(forward_graph.GetNodeArg(weight_name));
+  // Add initializers to forward graph inputs.
+  for (const auto& initializer_name : split_graphs_info_.initializer_names_to_train) {
+    forward_input_args.emplace_back(forward_graph.GetNodeArg(initializer_name));
   }
 
   forward_graph.SetInputs(forward_input_args);
 
+  // All user outputs should also be part of the forward graph outputs.
   std::vector<const NodeArg*> forward_output_args;
-  for (const auto& output_name : graph_output_names) {
-    forward_output_args.push_back(forward_graph.GetNodeArg(output_name));
+  for (const auto& output_name : split_graphs_info_.user_output_names) {
+    forward_output_args.emplace_back(forward_graph.GetNodeArg(output_name));
   }
 
   // Add intermediate args to forward graph outputs.
   for (const auto& intermediate_arg_name : intermediate_arg_names) {
-    // Ignore those duplicates.
-    if (config.output_names.find(intermediate_arg_name) == config.output_names.end()) {
-      forward_output_args.push_back(forward_graph.GetNodeArg(intermediate_arg_name));
+    // Ignore the user outputs.
+    if (std::find(split_graphs_info_.user_output_names.begin(), split_graphs_info_.user_output_names.end(),
+                  intermediate_arg_name) == split_graphs_info_.user_output_names.end()) {
+      split_graphs_info_.intermediate_tensor_names.emplace_back(intermediate_arg_name);
+      forward_output_args.emplace_back(forward_graph.GetNodeArg(intermediate_arg_name));
     }
   }
 
   forward_graph.SetOutputs(forward_output_args);
 
-  // Resolve the forward graph, keep the weight initializers for now.
+  // Resolve the forward graph, keep the trainable initializers for now.
   Graph::ResolveOptions options;
-  options.initializer_names_to_preserve = &config.weight_names_to_train;
+  std::unordered_set<std::string> initializer_names_to_train_set(split_graphs_info_.initializer_names_to_train.begin(),
+                                                                 split_graphs_info_.initializer_names_to_train.end());
+  options.initializer_names_to_preserve = &initializer_names_to_train_set;
   forward_graph.Resolve(options);
 
   // Get backward graph.
@@ -267,44 +276,46 @@ Status ModuleGradientGraphBuilder::Split(const ModuleGradientGraphBuilderConfigu
   for (auto node_index : backward_node_topology_list) {
     auto& node = *backward_graph.GetNode(node_index);
     if (node.Description() != "Backward pass") {
-      backward_nodes_to_remove.push_back(&node);
+      backward_nodes_to_remove.emplace_back(&node);
    }
   }
 
   RemoveNodes(backward_graph, backward_nodes_to_remove);
 
-  const std::vector<const NodeArg*>& backward_graph_inputs = backward_graph.GetInputsIncludingInitializers();
   std::vector<const NodeArg*> backward_input_args;
-  for (auto& node_arg : backward_graph_inputs) {
+  for (const auto& input_name : split_graphs_info_.user_input_names) {
     // Only take those in the backward inputs.
-    if (backward_input_names.find(node_arg->Name()) != backward_input_names.end()) {
-      backward_input_args.push_back(node_arg);
+    if (backward_input_names.find(input_name) != backward_input_names.end()) {
+      split_graphs_info_.backward_user_input_names.emplace_back(input_name);
+      backward_input_args.emplace_back(backward_graph.GetNodeArg(input_name));
     }
   }
 
-  // Add weight args to backward graph inputs if any node uses them.
-  for (const auto& weight_name : config.weight_names_to_train) {
-    // Weights will be inputs for backward graph.
-    if (backward_input_names.find(weight_name) != backward_input_names.end()) {
-      backward_input_args.push_back(backward_graph.GetNodeArg(weight_name));
-      backward_graph.RemoveInitializedTensor(weight_name);
+  // Add initializer args to backward graph inputs if any node uses them.
+  for (const auto& initializer_name : split_graphs_info_.initializer_names_to_train) {
+    // Some initializers will be inputs of the backward graph.
+    if (backward_input_names.find(initializer_name) != backward_input_names.end()) {
+      split_graphs_info_.backward_intializer_names_as_input.emplace_back(initializer_name);
+      backward_input_args.emplace_back(backward_graph.GetNodeArg(initializer_name));
+      backward_graph.RemoveInitializedTensor(initializer_name);
    }
   }
 
   // Add intermediate args to backward graph inputs.
-  for (const auto& intermediate_arg_name : intermediate_arg_names) {
+  for (const auto& intermediate_arg_name : split_graphs_info_.intermediate_tensor_names) {
     NodeArg* intermediate_node_arg = backward_graph.GetNodeArg(intermediate_arg_name);
     intermediate_node_arg->UpdateTypeAndShape(*forward_graph.GetNodeArg(intermediate_arg_name), true, true, *logger_);
-    backward_input_args.push_back(intermediate_node_arg);
+    backward_input_args.emplace_back(intermediate_node_arg);
   }
 
   backward_graph.SetInputs(backward_input_args);
 
+  // Exclude user outputs from the backward graph outputs.
   const std::vector<const NodeArg*>& backward_graph_outputs = backward_graph.GetOutputs();
   std::vector<const NodeArg*> backward_output_args;
   for (auto& node_arg : backward_graph_outputs) {
     if (backward_output_names.find(node_arg->Name()) != backward_output_names.end()) {
-      backward_output_args.push_back(node_arg);
+      backward_output_args.emplace_back(node_arg);
    }
   }
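Note: the split above hinges on a naming convention rather than explicit wiring: for each user output `y` consumed by the backward nodes, the gradient graph looks for a tensor named `y_grad`, and promotes it to a graph input only when no node produces it. A minimal Python sketch of that rule (the helper name and sample names are illustrative, not from this change):

```python
# Sketch of the gradient-entry rule used by BuildAndSplit above (illustrative).
def backward_output_grad_names(user_output_names, node_inputs, node_outputs):
    """Return the y_grad names that must become gradient graph inputs."""
    names = []
    for output_name in user_output_names:
        grad_name = output_name + '_grad'
        # Consumed by some node, produced by none: must be fed from outside.
        if grad_name in node_inputs and grad_name not in node_outputs:
            names.append(grad_name)
    return names

# Example: loss_grad is consumed by backward nodes but produced by none of them.
print(backward_output_grad_names(['loss'], {'loss_grad'}, set()))  # ['loss_grad']
```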
diff --git a/orttraining/orttraining/core/framework/module_gradient_graph_builder.h b/orttraining/orttraining/core/framework/module_gradient_graph_builder.h
index e01635d8d0..344f47a7f1 100644
--- a/orttraining/orttraining/core/framework/module_gradient_graph_builder.h
+++ b/orttraining/orttraining/core/framework/module_gradient_graph_builder.h
@@ -13,34 +13,54 @@ namespace training {
  * The training configuration options.
  */
 struct ModuleGradientGraphBuilderConfiguration {
-  // The names of the weights to train.
-  std::unordered_set<std::string> weight_names_to_train{};
-  // The names of inputs that require gradient.
-  std::unordered_set<std::string> input_names_require_grad{};
-  // The names of module outputs.
-  std::unordered_set<std::string> output_names{};
+  // The names of the initializers to train.
+  std::vector<std::string> initializer_names_to_train{};
+  // The names of inputs that require gradient.
+  std::vector<std::string> input_names_require_grad{};
 
   // Gradient graph configuration.
   bool use_invertible_layernorm_grad = false;
   bool set_gradients_as_graph_outputs = false;
 
   // TODO: add GraphTransformerConfiguration
   // TODO: add mixed precision config
   // TODO: do we need to support graph with loss?
+};
+
+/**
+ * The information of the split graphs for the frontend.
+ */
+struct SplitGraphsInfo {
+  std::vector<std::string> user_input_names{};
+  std::vector<std::string> initializer_names_to_train{};
+  std::vector<std::string> user_output_names{};
+  std::vector<std::string> backward_user_input_names{};
+  std::vector<std::string> backward_intializer_names_as_input{};
+  std::vector<std::string> intermediate_tensor_names{};
+  std::vector<std::string> user_output_grad_names{};
+  std::vector<std::string> backward_output_grad_names{};
 };
 
 class ModuleGradientGraphBuilder {
  public:
   Status BuildAndSplit(std::istream& model_istream,
-                       const ModuleGradientGraphBuilderConfiguration& config,
-                       std::vector<std::string>& models_as_string);
+                       const ModuleGradientGraphBuilderConfiguration& config);
+
+  std::string GetGradientModel() const;
+  std::string GetForwardModel() const;
+  std::string GetBackwardModel() const;
+
+  SplitGraphsInfo GetSplitGraphsInfo() const {
+    return split_graphs_info_;
+  }
+
  private:
-  Status Split(const ModuleGradientGraphBuilderConfiguration& config,
-               const std::vector<std::string>& graph_output_names);
+  Status Split();
 
   std::shared_ptr<Model> model_;
   std::shared_ptr<Model> forward_model_;
   std::shared_ptr<Model> backward_model_;
+  SplitGraphsInfo split_graphs_info_;
+
   const logging::Logger* logger_;
 };
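Note: replacing the `std::unordered_set` fields with `std::vector` keeps name order stable across the C++/Python boundary. A hedged sketch of populating the new configuration from Python (the initializer and input names are placeholders, not from this change):

```python
from onnxruntime.capi import _pybind_state as C

config = C.ModuleGradientGraphBuilderConfiguration()
# Ordered lists now instead of sets; 'fc1.weight' and 'input1' are placeholders.
config.initializer_names_to_train = ['fc1.weight', 'fc1.bias']
config.input_names_require_grad = ['input1']
config.use_invertible_layernorm_grad = False
config.set_gradients_as_graph_outputs = False
```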
diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc
index e3bebb2a85..9125eef5f4 100644
--- a/orttraining/orttraining/python/orttraining_pybind_state.cc
+++ b/orttraining/orttraining/python/orttraining_pybind_state.cc
@@ -357,11 +357,22 @@ void addObjectMethodsForTraining(py::module& m) {
   py::class_<ModuleGradientGraphBuilderConfiguration> module_gradient_graph_builder_config(
       m, "ModuleGradientGraphBuilderConfiguration",
       R"pbdoc(Configuration information for module gradient graph builder.)pbdoc");
   module_gradient_graph_builder_config.def(py::init())
-      .def_readwrite("weight_names_to_train", &ModuleGradientGraphBuilderConfiguration::weight_names_to_train)
+      .def_readwrite("initializer_names_to_train", &ModuleGradientGraphBuilderConfiguration::initializer_names_to_train)
       .def_readwrite("input_names_require_grad", &ModuleGradientGraphBuilderConfiguration::input_names_require_grad)
-      .def_readwrite("output_names", &ModuleGradientGraphBuilderConfiguration::output_names)
       .def_readwrite("use_invertible_layernorm_grad", &ModuleGradientGraphBuilderConfiguration::use_invertible_layernorm_grad)
       .def_readwrite("set_gradients_as_graph_outputs", &ModuleGradientGraphBuilderConfiguration::set_gradients_as_graph_outputs);
+
+  py::class_<SplitGraphsInfo> split_graphs_info(
+      m, "SplitGraphsInfo", R"pbdoc(The information of the split graphs for the frontend.)pbdoc");
+  split_graphs_info.def(py::init())
+      .def_readwrite("user_input_names", &SplitGraphsInfo::user_input_names)
+      .def_readwrite("initializer_names_to_train", &SplitGraphsInfo::initializer_names_to_train)
+      .def_readwrite("user_output_names", &SplitGraphsInfo::user_output_names)
+      .def_readwrite("backward_user_input_names", &SplitGraphsInfo::backward_user_input_names)
+      .def_readwrite("backward_intializer_names_as_input", &SplitGraphsInfo::backward_intializer_names_as_input)
+      .def_readwrite("intermediate_tensor_names", &SplitGraphsInfo::intermediate_tensor_names)
+      .def_readwrite("user_output_grad_names", &SplitGraphsInfo::user_output_grad_names)
+      .def_readwrite("backward_output_grad_names", &SplitGraphsInfo::backward_output_grad_names);
 
   py::class_<ModuleGradientGraphBuilder> module_gradient_graph_builder(m, "ModuleGradientGraphBuilder");
   module_gradient_graph_builder
@@ -372,14 +383,19 @@ void addObjectMethodsForTraining(py::module& m) {
                                  const py::bytes& serialized_model,
                                  const ModuleGradientGraphBuilderConfiguration& config) {
         std::istringstream buffer(serialized_model);
-        std::vector<std::string> models_as_string;
-        ORT_THROW_IF_ERROR(module_gradient_graph_builder->BuildAndSplit(buffer, config, models_as_string));
-        std::vector<py::bytes> models_as_bytes;
-        for (size_t i = 0; i < 3; i++) {
-          models_as_bytes.push_back(py::bytes(models_as_string[i]));
-        }
-
-        return models_as_bytes;
+        ORT_THROW_IF_ERROR(module_gradient_graph_builder->BuildAndSplit(buffer, config));
+      })
+      .def("get_gradient_model", [](ModuleGradientGraphBuilder* module_gradient_graph_builder) {
+        return py::bytes(module_gradient_graph_builder->GetGradientModel());
+      })
+      .def("get_forward_model", [](ModuleGradientGraphBuilder* module_gradient_graph_builder) {
+        return py::bytes(module_gradient_graph_builder->GetForwardModel());
+      })
+      .def("get_backward_model", [](ModuleGradientGraphBuilder* module_gradient_graph_builder) {
+        return py::bytes(module_gradient_graph_builder->GetBackwardModel());
+      })
+      .def("get_split_graphs_info", [](ModuleGradientGraphBuilder* module_gradient_graph_builder) {
+        return module_gradient_graph_builder->GetSplitGraphsInfo();
+      });
 }
 }  // namespace python
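Note: with the bindings above, `build_and_split` no longer returns the serialized models; they are fetched through dedicated getters. A minimal end-to-end sketch (the model path is a placeholder, and training every initializer is just for illustration):

```python
import onnx
from onnxruntime.capi import _pybind_state as C

model = onnx.load('model.onnx')  # placeholder path

config = C.ModuleGradientGraphBuilderConfiguration()
config.initializer_names_to_train = [init.name for init in model.graph.initializer]

builder = C.ModuleGradientGraphBuilder()
builder.build_and_split(model.SerializeToString(), config)

# Serialized models are now pulled individually instead of returned as a list.
forward_model = onnx.load_model_from_string(builder.get_forward_model())
backward_model = onnx.load_model_from_string(builder.get_backward_model())
info = builder.get_split_graphs_info()
print(info.user_input_names, info.user_output_names)
```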
diff --git a/samples/python/mnist/graph_spliter.py b/samples/python/mnist/graph_spliter.py
index 8efea4967a..e278912dda 100644
--- a/samples/python/mnist/graph_spliter.py
+++ b/samples/python/mnist/graph_spliter.py
@@ -3,122 +3,23 @@ import copy
 from onnx import shape_inference
 from onnxruntime.capi import _pybind_state as C
 
-def add_input_from_initializer(model, initializer, docstring=None):
-    new_input = onnx.helper.make_tensor_value_info(initializer.name, initializer.data_type, initializer.dims, docstring)
-    model.graph.input.append(new_input)
-
-def add_input(model, name, data_type = None, dims = None, docstring = None):
-    new_input = onnx.helper.make_tensor_value_info(name, data_type, dims, docstring)
-    model.graph.input.append(new_input)
-
-def add_output(model, name, data_type = None, docstring = None):
-    new_output = model.graph.value_info.add()
-    new_output.name = name
-    if data_type:
-        new_output.type.CopyFrom(data_type)
-    if docstring:
-        new_output.doc_string = docstring
-    model.graph.output.append(new_output)
-
-def remove_nodes(onnx_model, nodes_to_remove):
-    all_nodes = []
-    for node in onnx_model.graph.node:
-        if node not in nodes_to_remove:
-            all_nodes.append(node)
-
-    onnx_model.graph.ClearField('node')
-    onnx_model.graph.node.extend(all_nodes)
-
-def split_graph(onnx_model):
-    forward_graph_outputs = set()
-    backward_graph_inputs = set()
-    backward_graph_outputs = set()
-
-    # Get forward graph
-    forward_model = copy.deepcopy(onnx_model)
-    nodes_to_remove_from_forward_graph = []
-    initializers = {}
-    for initializer in forward_model.graph.initializer:
-        initializers[initializer.name] = initializer
-
-    forward_graph_initializer_names = set()
-    for node in forward_model.graph.node:
-        if node.doc_string == 'Backward pass':
-            # nodes belonging to the backward graph
-            nodes_to_remove_from_forward_graph.append(node)
-            for input in node.input:
-                backward_graph_inputs.add(input)
-            for output in node.output:
-                backward_graph_outputs.add(output)
-        else:
-            # nodes belonging to the forward graph
-            for input in node.input:
-                if input in initializers:
-                    forward_graph_initializer_names.add(input)
-            for output in node.output:
-                forward_graph_outputs.add(output)
-
-    forward_model.graph.ClearField('initializer')
-    for initializer_name in forward_graph_initializer_names:
-        forward_model.graph.initializer.append(initializers[initializer_name])
-
-    # Outputs from the forward graph that are also inputs of the backward graph need to be added as graph outputs.
-    for output in forward_graph_outputs:
-        if output in backward_graph_inputs:
-            add_output(forward_model, output)
-
-    remove_nodes(forward_model, nodes_to_remove_from_forward_graph)
-
-    # Get backward graph
-    tensor_elem_types = {}
-    infered_model = shape_inference.infer_shapes(onnx_model)
-    for value_info in infered_model.graph.value_info:
-        tensor_elem_types[value_info.name] = value_info.type.tensor_type.elem_type
-
-    backward_model = copy.deepcopy(onnx_model)
-    initializers = {}
-    for initializer in backward_model.graph.initializer:
-        initializers[initializer.name] = initializer
-
-    nodes_to_remove_from_backward_graph = []
-    for node in backward_model.graph.node:
-        if node.doc_string != 'Backward pass':
-            nodes_to_remove_from_backward_graph.append(node)
-
-    # The gradient of a forward graph output will be an input of the backward graph.
-    for output in backward_model.graph.output:
-        if output.name + '_grad' in backward_graph_inputs:
-            add_input(backward_model, output.name + '_grad', output.type.tensor_type.elem_type)
-
-    backward_graph_initializer_names = set()
-    for input in backward_graph_inputs:
-        if input in forward_graph_outputs:
-            # Inputs of the backward graph that are also outputs of the forward graph need to be added as backward graph inputs.
-            add_input(backward_model, input, tensor_elem_types[input] if input in tensor_elem_types else 1)
-        elif input in forward_graph_initializer_names:
-            # Inputs from forward graph initializers need to be added as backward graph inputs.
-            add_input_from_initializer(backward_model, initializers[input])
-        elif input in initializers:
-            backward_graph_initializer_names.add(input)
-
-    backward_model.graph.ClearField('initializer')
-    for initializer_name in backward_graph_initializer_names:
-        backward_model.graph.initializer.append(initializers[initializer_name])
-
-    # Add gradient outputs to the backward graph outputs.
-    # TODO: need to add gradients of graph inputs to the backward graph outputs.
-    new_backward_graph_outputs = set()
-    for output in backward_graph_outputs:
-        if output.endswith('_grad') and output[:-5] in forward_graph_initializer_names:
-            new_backward_graph_outputs.add(output)
-
-    backward_model.graph.ClearField('output')
-    for output in new_backward_graph_outputs:
-        add_output(backward_model, output)
-
-    remove_nodes(backward_model, nodes_to_remove_from_backward_graph)
-
-    return forward_model, backward_model
+def print_list(name, value):
+    print(name + ':', ', '.join(value))
 
+def dim_str(dim):
+    if dim.HasField('dim_value'):
+        return str(dim.dim_value)
+    elif dim.HasField('dim_param'):
+        return dim.dim_param
+    return 'n/a'
+
+def print_type(name, type):
+    print('[' + name + ']', 'type:', type.tensor_type.elem_type, '| size:', '[' + ','.join([dim_str(d) for d in type.tensor_type.shape.dim]) + ']')
+
+
+"""
 # MNIST
 original_model = onnx.load('mnist_original.onnx')
 config = C.ModuleGradientGraphBuilderConfiguration()
@@ -137,7 +38,6 @@ onnx.save(models[1], 'mnist_forward.onnx')
 onnx.save(models[2], 'mnist_backward.onnx')
 
-"""
 #BERT
 original_model = onnx.load('BertForSequenceClassification_full_training.onnx')
 config = C.ModuleGradientGraphBuilderConfiguration()
@@ -154,27 +54,67 @@ models = [onnx.load_model_from_string(model_as_string) for model_as_string in C.
 onnx.save(models[0], 'bert_gradient_graph.onnx')
 onnx.save(models[1], 'bert_forward.onnx')
 onnx.save(models[2], 'bert_backward.onnx')
-
+"""
 #BERT with loss
 original_model = onnx.load('bert-tiny-loss.onnx')
 config = C.ModuleGradientGraphBuilderConfiguration()
-weight_names_to_train = set()
+initializer_names_to_train = []
 for initializer in original_model.graph.initializer:
     if initializer.name.startswith('bert.') or initializer.name.startswith('cls.'):
-        weight_names_to_train.add(initializer.name)
-config.weight_names_to_train = weight_names_to_train
-input_names_require_grad = set()
-input_names_require_grad.add('input3')
+        initializer_names_to_train.append(initializer.name)
+config.initializer_names_to_train = initializer_names_to_train
+input_names_require_grad = []
+input_names_require_grad.append('input3')
 config.input_names_require_grad = input_names_require_grad
-output_names = set()
-#output_names.add('total_loss')
-for output in original_model.graph.output:
-    output_names.add(output.name)
-config.output_names = output_names
-models = [onnx.load_model_from_string(model_as_string) for model_as_string in C.ModuleGradientGraphBuilder().build_and_split(original_model.SerializeToString(), config)]
-onnx.save(models[0], 'bert_gradient_graph.onnx')
-onnx.save(models[1], 'bert_forward.onnx')
-onnx.save(models[2], 'bert_backward.onnx')
-"""
+module_gradient_graph_builder = C.ModuleGradientGraphBuilder()
+module_gradient_graph_builder.build_and_split(original_model.SerializeToString(), config)
+
+forward_model = onnx.load_model_from_string(module_gradient_graph_builder.get_forward_model())
+backward_model = onnx.load_model_from_string(module_gradient_graph_builder.get_backward_model())
+onnx.save(onnx.load_model_from_string(module_gradient_graph_builder.get_gradient_model()), 'bert_gradient_graph.onnx')
+onnx.save(forward_model, 'bert_forward.onnx')
+onnx.save(backward_model, 'bert_backward.onnx')
+
+split_graphs_info = module_gradient_graph_builder.get_split_graphs_info()
+print_list('user_input_names', split_graphs_info.user_input_names)
+print_list('initializer_names_to_train', split_graphs_info.initializer_names_to_train)
+print_list('user_output_names', split_graphs_info.user_output_names)
+print_list('backward_user_input_names', split_graphs_info.backward_user_input_names)
+print_list('backward_intializer_names_as_input', split_graphs_info.backward_intializer_names_as_input)
+print_list('intermediate_tensor_names', split_graphs_info.intermediate_tensor_names)
+print_list('user_output_grad_names', split_graphs_info.user_output_grad_names)
+print_list('backward_output_grad_names', split_graphs_info.backward_output_grad_names)
+
+type_map = {}
+for name in split_graphs_info.user_input_names:
+    type_map[name] = None
+for name in split_graphs_info.initializer_names_to_train:
+    type_map[name] = None
+for name in split_graphs_info.user_output_names:
+    type_map[name] = None
+for name in split_graphs_info.backward_user_input_names:
+    type_map[name] = None
+for name in split_graphs_info.backward_intializer_names_as_input:
+    type_map[name] = None
+for name in split_graphs_info.intermediate_tensor_names:
+    type_map[name] = None
+for name in split_graphs_info.user_output_grad_names:
+    type_map[name] = None
+for name in split_graphs_info.backward_output_grad_names:
+    type_map[name] = None
+
+for input in forward_model.graph.input:
+    if input.name in type_map and type_map[input.name] is None:
+        type_map[input.name] = input.type
+
+for output in forward_model.graph.output:
+    if output.name in type_map and type_map[output.name] is None:
+        type_map[output.name] = output.type
+    output_grad_name = output.name + '_grad'
+    if output_grad_name in type_map and type_map[output_grad_name] is None:
+        type_map[output_grad_name] = output.type
+
+for key, value in type_map.items():
+    print_type(key, value)
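Note: the final loop assumes every collected name was resolved to a type from the forward model's inputs and outputs. A defensive variant (my addition, not part of the sample) that tolerates unresolved names instead of failing inside `print_type`:

```python
# Variant of the sample's final loop that handles unresolved entries.
for key, value in type_map.items():
    if value is None:
        # Not found among the forward model's inputs/outputs.
        print('[' + key + '] type: unknown')
    else:
        print_type(key, value)
```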