split graphs info

Vincent Wang 2020-11-11 10:20:38 +00:00 committed by Thiago Crepaldi
parent cfd57c0136
commit f6a8d2aa5f
4 changed files with 233 additions and 246 deletions


@@ -45,23 +45,39 @@ void FilterInitializers(Graph& graph, const std::unordered_set<std::string>& inp
}
Status ModuleGradientGraphBuilder::BuildAndSplit(std::istream& model_istream,
const ModuleGradientGraphBuilderConfiguration& config,
std::vector<std::string>& models_as_string) {
const ModuleGradientGraphBuilderConfiguration& config) {
logger_ = &logging::LoggingManager::DefaultLogger(); // use default logger for now.
ONNX_NAMESPACE::ModelProto model_proto;
ORT_RETURN_IF_ERROR(Model::Load(model_istream, &model_proto));
ORT_RETURN_IF_ERROR(Model::Load(model_proto, model_, nullptr, *logger_));
ORT_RETURN_IF_ERROR(model_->MainGraph().Resolve());
// Handle original model inputs, outputs and trainable initializers.
const std::vector<const NodeArg*>& graph_inputs = model_->MainGraph().GetInputsIncludingInitializers();
for (auto& node_arg : graph_inputs) {
split_graphs_info_.user_input_names.emplace_back(node_arg->Name());
}
const std::vector<const NodeArg*>& graph_outputs = model_->MainGraph().GetOutputs();
for (auto& node_arg : graph_outputs) {
split_graphs_info_.user_output_names.emplace_back(node_arg->Name());
}
split_graphs_info_.initializer_names_to_train.assign(config.initializer_names_to_train.begin(), config.initializer_names_to_train.end());
// Register and apply transformers for pre-training.
const TrainingSession::TrainingConfiguration::GraphTransformerConfiguration graph_transformer_config{};
GraphTransformerManager graph_transformation_mgr{2};
std::unique_ptr<CPUExecutionProvider> cpu_execution_provider =
onnxruntime::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
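// x_node_arg_names collects the differentiable "leaves": the union of the trainable
// initializers and the user inputs that require gradients. The same set feeds both
// the pre-training transformers and the gradient graph builder below.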
std::unordered_set<std::string> x_node_arg_names;
std::set_union(config.initializer_names_to_train.begin(), config.initializer_names_to_train.end(),
config.input_names_require_grad.begin(), config.input_names_require_grad.end(),
std::inserter(x_node_arg_names, x_node_arg_names.begin()));
auto add_transformers = [&](TransformerLevel level) {
auto transformers_to_register = transformer_utils::GeneratePreTrainingTransformers(
level, config.weight_names_to_train, graph_transformer_config, *cpu_execution_provider, {});
level, x_node_arg_names, graph_transformer_config, *cpu_execution_provider, {});
for (auto& entry : transformers_to_register) {
graph_transformation_mgr.Register(std::move(entry), level);
}
@@ -85,12 +101,9 @@ Status ModuleGradientGraphBuilder::BuildAndSplit(std::istream& model_istream,
GradientGraphConfiguration gradient_graph_config{};
gradient_graph_config.use_invertible_layernorm_grad = config.use_invertible_layernorm_grad;
gradient_graph_config.set_gradients_as_graph_outputs = config.set_gradients_as_graph_outputs;
std::unordered_set<std::string> x_node_arg_names;
std::set_union(config.weight_names_to_train.begin(), config.weight_names_to_train.end(),
config.input_names_require_grad.begin(), config.input_names_require_grad.end(),
std::inserter(x_node_arg_names, x_node_arg_names.begin()));
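// The gradient graph is built backward from the user outputs (y_node_arg_names)
// to the differentiable leaves (x_node_arg_names), requesting a gradient for each leaf.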
std::unordered_set<std::string> y_node_arg_names(split_graphs_info_.user_output_names.begin(), split_graphs_info_.user_output_names.end());
GradientGraphBuilder grad_graph_builder(&model_->MainGraph(),
config.output_names,
y_node_arg_names,
x_node_arg_names,
"", // not support loss name for now.
gradient_graph_config,
@@ -108,44 +121,38 @@ Status ModuleGradientGraphBuilder::BuildAndSplit(std::istream& model_istream,
GetInputAndOutputNames(node, input_names, output_names);
}
const std::vector<const NodeArg*>& gradient_graph_inputs = gradient_graph.GetInputsIncludingInitializers();
std::vector<std::string> graph_input_names;
std::vector<const NodeArg*> input_args;
for (auto& node_arg : gradient_graph_inputs) {
input_args.push_back(node_arg);
graph_input_names.push_back(node_arg->Name());
}
const std::vector<const NodeArg*>& gradient_graph_outputs = gradient_graph.GetOutputs();
std::vector<std::string> graph_output_names;
std::vector<const NodeArg*> output_args;
for (auto& node_arg : gradient_graph_outputs) {
output_args.push_back(node_arg);
graph_output_names.push_back(node_arg->Name());
for (auto& input_name : split_graphs_info_.user_input_names) {
input_args.emplace_back(gradient_graph.GetNodeArg(input_name));
}
// Add the gradient entry points (normally the loss gradient) to the graph inputs, following the order of the graph outputs.
for (const auto& output_name : graph_output_names) {
if (config.output_names.find(output_name) == config.output_names.end()) {
continue;
}
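// For example, for a user output named "loss" the backward pass consumes an input
// named "loss_grad", which the frontend typically feeds with ones for a scalar loss.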
for (const auto& output_name : split_graphs_info_.user_output_names) {
std::string output_gradient_name = output_name + "_grad";
if (input_names.find(output_gradient_name) != input_names.end() &&
output_names.find(output_gradient_name) == output_names.end()) {
NodeArg* output_gradient_node_arg = gradient_graph.GetNodeArg(output_gradient_name);
output_gradient_node_arg->UpdateTypeAndShape(*gradient_graph.GetNodeArg(output_name), true, true, *logger_);
input_args.push_back(output_gradient_node_arg);
if (input_names.find(output_gradient_name) != input_names.end()) {
split_graphs_info_.user_output_grad_names.emplace_back(output_gradient_name);
// Only add it to the graph inputs when it is not produced by any node.
if (output_names.find(output_gradient_name) == output_names.end()) {
split_graphs_info_.backward_output_grad_names.emplace_back(output_gradient_name);
NodeArg* output_gradient_node_arg = gradient_graph.GetNodeArg(output_gradient_name);
output_gradient_node_arg->UpdateTypeAndShape(*gradient_graph.GetNodeArg(output_name), true, true, *logger_);
input_args.emplace_back(output_gradient_node_arg);
}
}
}
gradient_graph.SetInputs(input_args);
// Add weight gradients to graph outputs.
for (const auto& weight_name : config.weight_names_to_train) {
std::string weight_gradient_name = weight_name + "_grad";
if (output_names.find(weight_gradient_name) != output_names.end()) {
output_args.push_back(gradient_graph.GetNodeArg(weight_gradient_name));
std::vector<const NodeArg*> output_args;
for (auto& output_name : split_graphs_info_.user_output_names) {
output_args.emplace_back(gradient_graph.GetNodeArg(output_name));
}
// Add initializer gradients to graph outputs.
for (const auto& initializer_name : split_graphs_info_.initializer_names_to_train) {
std::string initializer_gradient_name = initializer_name + "_grad";
if (output_names.find(initializer_gradient_name) != output_names.end()) {
output_args.emplace_back(gradient_graph.GetNodeArg(initializer_gradient_name));
}
}
@@ -153,7 +160,7 @@ Status ModuleGradientGraphBuilder::BuildAndSplit(std::istream& model_istream,
for (const auto& input_name : config.input_names_require_grad) {
std::string input_gradient_name = input_name + "_grad";
if (output_names.find(input_gradient_name) != output_names.end()) {
output_args.push_back(gradient_graph.GetNodeArg(input_gradient_name));
output_args.emplace_back(gradient_graph.GetNodeArg(input_gradient_name));
}
}
@@ -167,33 +174,33 @@ Status ModuleGradientGraphBuilder::BuildAndSplit(std::istream& model_istream,
ORT_RETURN_IF_ERROR(Model::Load(gradient_model_proto, backward_model_, nullptr, *logger_));
// Split the graph in the copies of gradient model.
ORT_RETURN_IF_ERROR(Split(config, graph_output_names));
// Serialize the models as output for the frontend.
std::string gradient_model_str;
if (!model_->ToProto().SerializeToString(&gradient_model_str)) {
return Status(ONNXRUNTIME, FAIL, "Fail to serialize gradient model to string.");
}
std::string forward_model_str;
if (!forward_model_->ToProto().SerializeToString(&forward_model_str)) {
return Status(ONNXRUNTIME, FAIL, "Fail to serialize forward model to string.");
}
std::string backward_model_str;
if (!backward_model_->ToProto().SerializeToString(&backward_model_str)) {
return Status(ONNXRUNTIME, FAIL, "Fail to serialize backward model to string.");
}
models_as_string.push_back(gradient_model_str);
models_as_string.push_back(forward_model_str);
models_as_string.push_back(backward_model_str);
ORT_RETURN_IF_ERROR(Split());
return Status::OK();
}
Status ModuleGradientGraphBuilder::Split(const ModuleGradientGraphBuilderConfiguration& config,
const std::vector<std::string>& graph_output_names) {
std::string SerializeModel(const std::shared_ptr<onnxruntime::Model>& model, const std::string& tag) {
std::string model_str;
if (!model->ToProto().SerializeToString(&model_str)) {
ORT_THROW("Fail to serialize", tag, "model to string.");
}
return model_str;
}
std::string ModuleGradientGraphBuilder::GetGradientModel() const {
return SerializeModel(model_, "gradient");
}
std::string ModuleGradientGraphBuilder::GetForwardModel() const {
return SerializeModel(forward_model_, "forward");
}
std::string ModuleGradientGraphBuilder::GetBackwardModel() const {
return SerializeModel(backward_model_, "backward");
}
Status ModuleGradientGraphBuilder::Split() {
// Get the forward model; also collect information needed for backward model generation.
Graph& forward_graph = forward_model_->MainGraph();
GraphViewer forward_graph_viewer(forward_graph);
@@ -207,7 +214,7 @@ Status ModuleGradientGraphBuilder::Split(const ModuleGradientGraphBuilderConfigu
auto& node = *forward_graph.GetNode(node_index);
// Currently we are using node description to distinguish the forward and backward nodes.
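// Gradient nodes are tagged with this description when the gradient graph is built,
// so a plain string comparison is enough to partition the nodes.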
if (node.Description() == "Backward pass") {
forward_nodes_to_remove.push_back(&node);
forward_nodes_to_remove.emplace_back(&node);
GetInputAndOutputNames(node, backward_input_names, backward_output_names);
} else {
GetInputAndOutputNames(node, forward_input_names, forward_output_names);
@@ -224,39 +231,41 @@ Status ModuleGradientGraphBuilder::Split(const ModuleGradientGraphBuilderConfigu
RemoveNodes(forward_graph, forward_nodes_to_remove);
FilterInitializers(forward_graph, forward_input_names);
const std::vector<const NodeArg*>& forward_graph_inputs = forward_graph.GetInputsIncludingInitializers();
// All user inputs should also be part of the forward graph inputs.
std::vector<const NodeArg*> forward_input_args;
for (const NodeArg* node_arg : forward_graph_inputs) {
if (forward_input_names.find(node_arg->Name()) != forward_input_names.end()) {
forward_input_args.push_back(node_arg);
}
for (const auto& input_name : split_graphs_info_.user_input_names) {
forward_input_args.emplace_back(forward_graph.GetNodeArg(input_name));
}
// Add weights to forward graph inputs.
for (const auto& weight_name : config.weight_names_to_train) {
forward_input_args.push_back(forward_graph.GetNodeArg(weight_name));
// Add initializers to forward graph inputs.
for (const auto& initializer_name : split_graphs_info_.initializer_names_to_train) {
forward_input_args.emplace_back(forward_graph.GetNodeArg(initializer_name));
}
forward_graph.SetInputs(forward_input_args);
// All user outputs should also be part of the forward graph outputs.
std::vector<const NodeArg*> forward_output_args;
for (const auto& output_name : graph_output_names) {
forward_output_args.push_back(forward_graph.GetNodeArg(output_name));
for (const auto& output_name : split_graphs_info_.user_output_names) {
forward_output_args.emplace_back(forward_graph.GetNodeArg(output_name));
}
// Add intermediate args to forward graph outputs.
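// These are forward tensors that backward nodes consume; exposing them as forward
// graph outputs lets the frontend stash them and feed them to the backward graph.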
for (const auto& intermediate_arg_name : intermediate_arg_names) {
// Ignore those duplicates.
if (config.output_names.find(intermediate_arg_name) == config.output_names.end()) {
forward_output_args.push_back(forward_graph.GetNodeArg(intermediate_arg_name));
// Ignore the user outputs.
if (std::find(split_graphs_info_.user_output_names.begin(), split_graphs_info_.user_output_names.end(), intermediate_arg_name)
== split_graphs_info_.user_output_names.end()) {
split_graphs_info_.intermediate_tensor_names.emplace_back(intermediate_arg_name);
forward_output_args.emplace_back(forward_graph.GetNodeArg(intermediate_arg_name));
}
}
forward_graph.SetOutputs(forward_output_args);
// Resolve the forward graph, keep the weight initializers for now.
// Resolve the forward graph, keeping the trainable initializers for now.
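// Without the preserve list, Graph::Resolve may prune initializers that no longer
// have a consumer after the backward nodes were removed.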
Graph::ResolveOptions options;
options.initializer_names_to_preserve = &config.weight_names_to_train;
std::unordered_set<std::string> initializer_names_to_train_set(split_graphs_info_.initializer_names_to_train.begin(), split_graphs_info_.initializer_names_to_train.end());
options.initializer_names_to_preserve = &initializer_names_to_train_set;
forward_graph.Resolve(options);
// Get backward graph.
@@ -267,44 +276,46 @@ Status ModuleGradientGraphBuilder::Split(const ModuleGradientGraphBuilderConfigu
for (auto node_index : backward_node_topology_list) {
auto& node = *backward_graph.GetNode(node_index);
if (node.Description() != "Backward pass") {
backward_nodes_to_remove.push_back(&node);
backward_nodes_to_remove.emplace_back(&node);
}
}
RemoveNodes(backward_graph, backward_nodes_to_remove);
const std::vector<const NodeArg*>& backward_graph_inputs = backward_graph.GetInputsIncludingInitializers();
std::vector<const NodeArg*> backward_input_args;
for (auto& node_arg : backward_graph_inputs) {
for (const auto& input_name : split_graphs_info_.user_input_names) {
// Only take the user inputs that the backward graph actually consumes.
if (backward_input_names.find(node_arg->Name()) != backward_input_names.end()) {
backward_input_args.push_back(node_arg);
if (backward_input_names.find(input_name) != backward_input_names.end()) {
split_graphs_info_.backward_user_input_names.emplace_back(input_name);
backward_input_args.emplace_back(backward_graph.GetNodeArg(input_name));
}
}
// Add weight args to backward graph inputs if any node uses them.
for (const auto& weight_name : config.weight_names_to_train) {
// Weights will be inputs for backward graph.
if (backward_input_names.find(weight_name) != backward_input_names.end()) {
backward_input_args.push_back(backward_graph.GetNodeArg(weight_name));
backward_graph.RemoveInitializedTensor(weight_name);
// Add initializer args to backward graph inputs if any node uses them.
for (const auto& initializer_name : split_graphs_info_.initializer_names_to_train) {
// Some initializers will be inputs of the backward graph.
if (backward_input_names.find(initializer_name) != backward_input_names.end()) {
split_graphs_info_.backward_intializer_names_as_input.emplace_back(initializer_name);
backward_input_args.emplace_back(backward_graph.GetNodeArg(initializer_name));
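// Drop the baked-in initializer: the frontend now feeds the current weight value
// through the corresponding graph input instead.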
backward_graph.RemoveInitializedTensor(initializer_name);
}
}
// Add intermediate args to backward graph inputs.
for (const auto& intermediate_arg_name : intermediate_arg_names) {
for (const auto& intermediate_arg_name : split_graphs_info_.intermediate_tensor_names) {
NodeArg* intermediate_node_arg = backward_graph.GetNodeArg(intermediate_arg_name);
intermediate_node_arg->UpdateTypeAndShape(*forward_graph.GetNodeArg(intermediate_arg_name), true, true, *logger_);
backward_input_args.push_back(intermediate_node_arg);
backward_input_args.emplace_back(intermediate_node_arg);
}
backward_graph.SetInputs(backward_input_args);
// Exclude user outputs from the backward graph.
const std::vector<const NodeArg*>& backward_graph_outputs = backward_graph.GetOutputs();
std::vector<const NodeArg*> backward_output_args;
for (auto& node_arg : backward_graph_outputs) {
if (backward_output_names.find(node_arg->Name()) != backward_output_names.end()) {
backward_output_args.push_back(node_arg);
backward_output_args.emplace_back(node_arg);
}
}


@@ -13,34 +13,54 @@ namespace training {
* The training configuration options.
*/
struct ModuleGradientGraphBuilderConfiguration {
// The names of the weights to train.
std::unordered_set<std::string> weight_names_to_train{};
// The names of inputs that require gradient.
std::unordered_set<std::string> input_names_require_grad{};
// The names of module outputs.
std::unordered_set<std::string> output_names{};
// The names of the weights to train.
std::vector<std::string> initializer_names_to_train{};
// The names of inputs that require gradient.
std::vector<std::string> input_names_require_grad{};
// Gradient graph configuration.
bool use_invertible_layernorm_grad = false;
bool set_gradients_as_graph_outputs = false;
// TODO: add GraphTransformerConfiguration
// TODO: add mixed precision config
// TODO: do we need to support graph with loss?
};
/**
* Information about the split graphs for the frontend.
*/
struct SplitGraphsInfo {
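// Names are recorded in graph order. The user_* fields describe the original model,
// the backward_* fields describe what the backward graph consumes and produces, and
// intermediate_tensor_names lists the forward outputs stashed for the backward pass.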
std::vector<std::string> user_input_names{};
std::vector<std::string> initializer_names_to_train{};
std::vector<std::string> user_output_names{};
std::vector<std::string> backward_user_input_names{};
std::vector<std::string> backward_intializer_names_as_input{};
std::vector<std::string> intermediate_tensor_names{};
std::vector<std::string> user_output_grad_names{};
std::vector<std::string> backward_output_grad_names{};
};
class ModuleGradientGraphBuilder {
public:
Status BuildAndSplit(std::istream& model_istream,
const ModuleGradientGraphBuilderConfiguration& config,
std::vector<std::string>& models_as_string);
const ModuleGradientGraphBuilderConfiguration& config);
std::string GetGradientModel() const;
std::string GetForwardModel() const;
std::string GetBackwardModel() const;
SplitGraphsInfo GetSplitGraphsInfo() const {
return split_graphs_info_;
}
private:
Status Split(const ModuleGradientGraphBuilderConfiguration& config,
const std::vector<std::string>& graph_output_names);
Status Split();
std::shared_ptr<onnxruntime::Model> model_;
std::shared_ptr<onnxruntime::Model> forward_model_;
std::shared_ptr<onnxruntime::Model> backward_model_;
SplitGraphsInfo split_graphs_info_;
const logging::Logger* logger_;
};


@@ -357,11 +357,22 @@ void addObjectMethodsForTraining(py::module& m) {
py::class_<ModuleGradientGraphBuilderConfiguration> module_gradient_graph_builder_config(
m, "ModuleGradientGraphBuilderConfiguration", R"pbdoc(Configuration information for module gradient graph builder.)pbdoc");
module_gradient_graph_builder_config.def(py::init())
.def_readwrite("weight_names_to_train", &ModuleGradientGraphBuilderConfiguration::weight_names_to_train)
.def_readwrite("initializer_names_to_train", &ModuleGradientGraphBuilderConfiguration::initializer_names_to_train)
.def_readwrite("input_names_require_grad", &ModuleGradientGraphBuilderConfiguration::input_names_require_grad)
.def_readwrite("output_names", &ModuleGradientGraphBuilderConfiguration::output_names)
.def_readwrite("use_invertible_layernorm_grad", &ModuleGradientGraphBuilderConfiguration::use_invertible_layernorm_grad)
.def_readwrite("set_gradients_as_graph_outputs", &ModuleGradientGraphBuilderConfiguration::set_gradients_as_graph_outputs);
py::class_<SplitGraphsInfo> split_graphs_info(
m, "SplitGraphsInfo", R"pbdoc(The information of split graphs for frontend.)pbdoc");
split_graphs_info.def(py::init())
.def_readwrite("user_input_names", &SplitGraphsInfo::user_input_names)
.def_readwrite("initializer_names_to_train", &SplitGraphsInfo::initializer_names_to_train)
.def_readwrite("user_output_names", &SplitGraphsInfo::user_output_names)
.def_readwrite("backward_user_input_names", &SplitGraphsInfo::backward_user_input_names)
.def_readwrite("backward_intializer_names_as_input", &SplitGraphsInfo::backward_intializer_names_as_input)
.def_readwrite("intermediate_tensor_names", &SplitGraphsInfo::intermediate_tensor_names)
.def_readwrite("user_output_grad_names", &SplitGraphsInfo::user_output_grad_names)
.def_readwrite("backward_output_grad_names", &SplitGraphsInfo::backward_output_grad_names);
py::class_<ModuleGradientGraphBuilder> module_gradient_graph_builder(m, "ModuleGradientGraphBuilder");
module_gradient_graph_builder
@@ -372,14 +383,19 @@ void addObjectMethodsForTraining(py::module& m) {
const py::bytes& serialized_model,
const ModuleGradientGraphBuilderConfiguration& config) {
std::istringstream buffer(serialized_model);
std::vector<std::string> models_as_string;
ORT_THROW_IF_ERROR(module_gradient_graph_builder->BuildAndSplit(buffer, config, models_as_string));
std::vector<py::bytes> models_as_bytes;
for (size_t i = 0; i < 3; i++) {
models_as_bytes.push_back(py::bytes(models_as_string[i]));
}
return models_as_bytes;
ORT_THROW_IF_ERROR(module_gradient_graph_builder->BuildAndSplit(buffer, config));
})
.def("get_gradient_model", [](ModuleGradientGraphBuilder* module_gradient_graph_builder) {
return py::bytes(module_gradient_graph_builder->GetGradientModel());
})
.def("get_forward_model", [](ModuleGradientGraphBuilder* module_gradient_graph_builder) {
return py::bytes(module_gradient_graph_builder->GetForwardModel());
})
.def("get_backward_model", [](ModuleGradientGraphBuilder* module_gradient_graph_builder) {
return py::bytes(module_gradient_graph_builder->GetBackwardModel());
})
.def("get_split_graphs_info", [](ModuleGradientGraphBuilder* module_gradient_graph_builder) {
return module_gradient_graph_builder->GetSplitGraphsInfo();
});
}
} // namespace python


@@ -3,122 +3,23 @@ import copy
from onnx import shape_inference
from onnxruntime.capi import _pybind_state as C
def add_input_from_initializer(model, initializer, docstring=None):
new_input = onnx.helper.make_tensor_value_info(initializer.name, initializer.data_type, initializer.dims, docstring)
model.graph.input.append(new_input)
def add_input(model, name, data_type=None, dims=None, docstring=None):
new_input = onnx.helper.make_tensor_value_info(name, data_type, dims, docstring)
model.graph.input.append(new_input)
def add_output(model, name, data_type=None, docstring=None):
new_output = model.graph.value_info.add()
new_output.name = name
if data_type:
new_output.type.CopyFrom(data_type)
if docstring:
new_output.doc_string = docstring
model.graph.output.append(new_output)
def remove_nodes(onnx_model, nodes_to_remove):
all_nodes = []
for node in onnx_model.graph.node:
if node not in nodes_to_remove:
all_nodes.append(node)
onnx_model.graph.ClearField('node')
onnx_model.graph.node.extend(all_nodes)
def split_graph(onnx_model):
forward_graph_outputs = set()
backward_graph_inputs = set()
backward_graph_outputs = set()
# Get forward graph
forward_model = copy.deepcopy(onnx_model)
nodes_to_remove_from_forward_graph = []
initializers = {}
for initializer in forward_model.graph.initializer:
initializers[initializer.name] = initializer
forward_graph_initializer_names = set()
for node in forward_model.graph.node:
if node.doc_string == 'Backward pass':
# nodes belonging to the backward graph
nodes_to_remove_from_forward_graph.append(node)
for input in node.input:
backward_graph_inputs.add(input)
for output in node.output:
backward_graph_outputs.add(output)
else:
# nodes belonging to the forward graph
for input in node.input:
if input in initializers:
forward_graph_initializer_names.add(input)
for output in node.output:
forward_graph_outputs.add(output)
forward_model.graph.ClearField('initializer')
for initializer_name in forward_graph_initializer_names:
forward_model.graph.initializer.append(initializers[initializer_name])
# Outputs of the forward graph that are also inputs of the backward graph need to be added as graph outputs.
for output in forward_graph_outputs:
if output in backward_graph_inputs:
add_output(forward_model, output)
remove_nodes(forward_model, nodes_to_remove_from_forward_graph)
# Get backward graph
tensor_elem_types = {}
inferred_model = shape_inference.infer_shapes(onnx_model)
for value_info in inferred_model.graph.value_info:
tensor_elem_types[value_info.name] = value_info.type.tensor_type.elem_type
backward_model = copy.deepcopy(onnx_model)
initializers = {}
for initializer in backward_model.graph.initializer:
initializers[initializer.name] = initializer
nodes_to_remove_from_backward_graph = []
for node in backward_model.graph.node:
if node.doc_string != 'Backward pass':
nodes_to_remove_from_backward_graph.append(node)
# Gradients of the forward graph outputs will be inputs of the backward graph
for output in backward_model.graph.output:
if output.name + '_grad' in backward_graph_inputs:
add_input(backward_model, output.name + '_grad', output.type.tensor_type.elem_type)
backward_graph_initializer_names = set()
for input in backward_graph_inputs:
if input in forward_graph_outputs:
# inputs of the backward graph that are also outputs of the forward graph need to be added to the backward graph inputs
add_input(backward_model, input, tensor_elem_types[input] if input in tensor_elem_types else 1)
elif input in forward_graph_initializer_names:
# inputs that are forward graph initializers need to be added to the backward graph inputs
add_input_from_initializer(backward_model, initializers[input])
elif input in initializers:
backward_graph_initializer_names.add(input)
backward_model.graph.ClearField('initializer')
for initializer_name in backward_graph_initializer_names:
backward_model.graph.initializer.append(initializers[initializer_name])
# Add gradient outputs to the backward graph outputs
# TODO: need to add gradient of graph input to backward graph output
new_backward_graph_outputs = set()
for output in backward_graph_outputs:
if output.endswith('_grad') and output[:-5] in forward_graph_initializer_names:
new_backward_graph_outputs.add(output)
backward_model.graph.ClearField('output')
for output in new_backward_graph_outputs:
add_output(backward_model, output)
remove_nodes(backward_model, nodes_to_remove_from_backward_graph)
return forward_model, backward_model
def print_list(name, value):
print(name + ':', ', '.join(value))
def dim_str(dim):
if dim.HasField('dim_value'):
return str(dim.dim_value)
elif dim.HasField('dim_param'):
return dim.dim_param
return 'n/a'
def print_type(name, type):
print('[' + name + ']', 'type:', type.tensor_type.elem_type, '| size:', '[' + ','.join([dim_str(d) for d in type.tensor_type.shape.dim]) + ']')
"""
# MNIST
original_model = onnx.load('mnist_original.onnx')
config = C.ModuleGradientGraphBuilderConfiguration()
@@ -137,7 +38,6 @@ onnx.save(models[1], 'mnist_forward.onnx')
onnx.save(models[2], 'mnist_backward.onnx')
"""
# BERT
original_model = onnx.load('BertForSequenceClassification_full_training.onnx')
config = C.ModuleGradientGraphBuilderConfiguration()
@@ -154,27 +54,67 @@ models = [onnx.load_model_from_string(model_as_string) for model_as_string in C.
onnx.save(models[0], 'bert_gradient_graph.onnx')
onnx.save(models[1], 'bert_forward.onnx')
onnx.save(models[2], 'bert_backward.onnx')
"""
# BERT with loss
original_model = onnx.load('bert-tiny-loss.onnx')
config = C.ModuleGradientGraphBuilderConfiguration()
weight_names_to_train = set()
initializer_names_to_train = []
for initializer in original_model.graph.initializer:
if initializer.name.startswith('bert.') or initializer.name.startswith('cls.'):
weight_names_to_train.add(initializer.name)
config.weight_names_to_train = weight_names_to_train
input_names_require_grad = set()
input_names_require_grad.add('input3')
initializer_names_to_train.append(initializer.name)
config.initializer_names_to_train = initializer_names_to_train
input_names_require_grad = []
input_names_require_grad.append('input3')
config.input_names_require_grad = input_names_require_grad
output_names = set()
#output_names.add('total_loss')
for output in original_model.graph.output:
output_names.add(output.name)
config.output_names = output_names
models = [onnx.load_model_from_string(model_as_string) for model_as_string in C.ModuleGradientGraphBuilder().build_and_split(original_model.SerializeToString(), config)]
onnx.save(models[0], 'bert_gradient_graph.onnx')
onnx.save(models[1], 'bert_forward.onnx')
onnx.save(models[2], 'bert_backward.onnx')
"""
module_gradient_graph_builder = C.ModuleGradientGraphBuilder()
module_gradient_graph_builder.build_and_split(original_model.SerializeToString(), config)
forward_model = onnx.load_model_from_string(module_gradient_graph_builder.get_forward_model())
backward_model = onnx.load_model_from_string(module_gradient_graph_builder.get_backward_model())
onnx.save(onnx.load_model_from_string(module_gradient_graph_builder.get_gradient_model()), 'bert_gradient_graph.onnx')
onnx.save(forward_model, 'bert_forward.onnx')
onnx.save(backward_model, 'bert_backward.onnx')
split_graphs_info = module_gradient_graph_builder.get_split_graphs_info()
print_list('user_input_names', split_graphs_info.user_input_names)
print_list('initializer_names_to_train', split_graphs_info.initializer_names_to_train)
print_list('user_output_names', split_graphs_info.user_output_names)
print_list('backward_user_input_names', split_graphs_info.backward_user_input_names)
print_list('backward_intializer_names_as_input', split_graphs_info.backward_intializer_names_as_input)
print_list('intermediate_tensor_names', split_graphs_info.intermediate_tensor_names)
print_list('user_output_grad_names', split_graphs_info.user_output_grad_names)
print_list('backward_output_grad_names', split_graphs_info.backward_output_grad_names)
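# Best-effort type lookup: seed every recorded name with None, then fill in the types
# from the forward model's inputs and outputs (gradient names reuse the matching
# forward output's type).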
type_map = {}
for name_list in [split_graphs_info.user_input_names,
                  split_graphs_info.initializer_names_to_train,
                  split_graphs_info.user_output_names,
                  split_graphs_info.backward_user_input_names,
                  split_graphs_info.backward_intializer_names_as_input,
                  split_graphs_info.intermediate_tensor_names,
                  split_graphs_info.user_output_grad_names,
                  split_graphs_info.backward_output_grad_names]:
    for name in name_list:
        type_map[name] = None
for input in forward_model.graph.input:
if input.name in type_map and type_map[input.name] is None:
type_map[input.name] = input.type
for output in forward_model.graph.output:
if output.name in type_map and type_map[output.name] is None:
type_map[output.name] = output.type
output_grad_name = output.name + '_grad'
if output_grad_name in type_map and type_map[output_grad_name] is None:
type_map[output_grad_name] = output.type
for key, value in type_map.items():
print_type(key, value)
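# A minimal sketch (not part of this commit), assuming the forward model saved above
# and an installed onnxruntime: run the forward graph once with zero-filled placeholder
# inputs and fetch both the user outputs and the intermediate tensors stashed for backward.
import numpy as np
import onnxruntime as ort

forward_session = ort.InferenceSession('bert_forward.onnx')
feeds = {}
for graph_input in forward_session.get_inputs():
    # Substitute 1 for symbolic dimensions; zeros are placeholders, not meaningful data.
    shape = [d if isinstance(d, int) else 1 for d in graph_input.shape]
    dtype = np.int64 if 'int64' in graph_input.type else np.float32
    feeds[graph_input.name] = np.zeros(shape, dtype=dtype)
fetches = list(split_graphs_info.user_output_names) + list(split_graphs_info.intermediate_tensor_names)
forward_outputs = forward_session.run(fetches, feeds)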