Mirror of https://github.com/saymrwulf/onnxruntime.git, synced 2026-05-16 21:00:14 +00:00
split graphs info

commit f6a8d2aa5f, parent cfd57c0136
4 changed files with 233 additions and 246 deletions
@@ -45,23 +45,39 @@ void FilterInitializers(Graph& graph, const std::unordered_set<std::string>& inp
 }
 
 Status ModuleGradientGraphBuilder::BuildAndSplit(std::istream& model_istream,
-                                                 const ModuleGradientGraphBuilderConfiguration& config,
-                                                 std::vector<std::string>& models_as_string) {
+                                                 const ModuleGradientGraphBuilderConfiguration& config) {
   logger_ = &logging::LoggingManager::DefaultLogger();  // use default logger for now.
   ONNX_NAMESPACE::ModelProto model_proto;
   ORT_RETURN_IF_ERROR(Model::Load(model_istream, &model_proto));
   ORT_RETURN_IF_ERROR(Model::Load(model_proto, model_, nullptr, *logger_));
   ORT_RETURN_IF_ERROR(model_->MainGraph().Resolve());
 
+  // Handle original model inputs, outputs and trainable initializers.
+  const std::vector<const NodeArg*>& graph_inputs = model_->MainGraph().GetInputsIncludingInitializers();
+  for (auto& node_arg : graph_inputs) {
+    split_graphs_info_.user_input_names.emplace_back(node_arg->Name());
+  }
+
+  const std::vector<const NodeArg*>& graph_outputs = model_->MainGraph().GetOutputs();
+  for (auto& node_arg : graph_outputs) {
+    split_graphs_info_.user_output_names.emplace_back(node_arg->Name());
+  }
+
+  split_graphs_info_.initializer_names_to_train.assign(config.initializer_names_to_train.begin(), config.initializer_names_to_train.end());
+
   // Register and apply transformers for pre-training.
   const TrainingSession::TrainingConfiguration::GraphTransformerConfiguration graph_transformer_config{};
   GraphTransformerManager graph_transformation_mgr{2};
   std::unique_ptr<CPUExecutionProvider> cpu_execution_provider =
       onnxruntime::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
 
+  std::unordered_set<std::string> x_node_arg_names;
+  std::set_union(config.initializer_names_to_train.begin(), config.initializer_names_to_train.end(),
+                 config.input_names_require_grad.begin(), config.input_names_require_grad.end(),
+                 std::inserter(x_node_arg_names, x_node_arg_names.begin()));
   auto add_transformers = [&](TransformerLevel level) {
     auto transformers_to_register = transformer_utils::GeneratePreTrainingTransformers(
-        level, config.weight_names_to_train, graph_transformer_config, *cpu_execution_provider, {});
+        level, x_node_arg_names, graph_transformer_config, *cpu_execution_provider, {});
     for (auto& entry : transformers_to_register) {
       graph_transformation_mgr.Register(std::move(entry), level);
     }
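As a sketch of the bookkeeping above: the std::set_union collects every graph leaf that needs a gradient, i.e. the trainable initializers plus the user inputs flagged as requiring gradients. The same union in Python terms (the names are illustrative, not from the commit):

# Union of differentiable leaves, mirroring the std::set_union call above.
initializer_names_to_train = ['bert.embeddings.word_embeddings.weight']  # hypothetical
input_names_require_grad = ['input3']                                    # hypothetical
x_node_arg_names = set(initializer_names_to_train) | set(input_names_require_grad)
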
@@ -85,12 +101,9 @@ Status ModuleGradientGraphBuilder::BuildAndSplit(std::istream& model_istream,
   GradientGraphConfiguration gradient_graph_config{};
   gradient_graph_config.use_invertible_layernorm_grad = config.use_invertible_layernorm_grad;
   gradient_graph_config.set_gradients_as_graph_outputs = config.set_gradients_as_graph_outputs;
-  std::unordered_set<std::string> x_node_arg_names;
-  std::set_union(config.weight_names_to_train.begin(), config.weight_names_to_train.end(),
-                 config.input_names_require_grad.begin(), config.input_names_require_grad.end(),
-                 std::inserter(x_node_arg_names, x_node_arg_names.begin()));
+  std::unordered_set<std::string> y_node_arg_names(split_graphs_info_.user_output_names.begin(), split_graphs_info_.user_output_names.end());
   GradientGraphBuilder grad_graph_builder(&model_->MainGraph(),
-                                          config.output_names,
+                                          y_node_arg_names,
                                           x_node_arg_names,
                                           "",  // loss name not supported for now.
                                           gradient_graph_config,
@@ -108,44 +121,38 @@ Status ModuleGradientGraphBuilder::BuildAndSplit(std::istream& model_istream,
     GetInputAndOutputNames(node, input_names, output_names);
   }
 
-  const std::vector<const NodeArg*>& gradient_graph_inputs = gradient_graph.GetInputsIncludingInitializers();
-  std::vector<std::string> graph_input_names;
   std::vector<const NodeArg*> input_args;
-  for (auto& node_arg : gradient_graph_inputs) {
-    input_args.push_back(node_arg);
-    graph_input_names.push_back(node_arg->Name());
-  }
-
-  const std::vector<const NodeArg*>& gradient_graph_outputs = gradient_graph.GetOutputs();
-  std::vector<std::string> graph_output_names;
-  std::vector<const NodeArg*> output_args;
-  for (auto& node_arg : gradient_graph_outputs) {
-    output_args.push_back(node_arg);
-    graph_output_names.push_back(node_arg->Name());
+  for (auto& input_name : split_graphs_info_.user_input_names) {
+    input_args.emplace_back(gradient_graph.GetNodeArg(input_name));
   }
 
   // Add the entry points of gradients (normally loss_grad) to the graph inputs. Using the order of graph outputs.
-  for (const auto& output_name : graph_output_names) {
-    if (config.output_names.find(output_name) == config.output_names.end()) {
-      continue;
-    }
-
+  for (const auto& output_name : split_graphs_info_.user_output_names) {
     std::string output_gradient_name = output_name + "_grad";
-    if (input_names.find(output_gradient_name) != input_names.end() &&
-        output_names.find(output_gradient_name) == output_names.end()) {
-      NodeArg* output_gradient_node_arg = gradient_graph.GetNodeArg(output_gradient_name);
-      output_gradient_node_arg->UpdateTypeAndShape(*gradient_graph.GetNodeArg(output_name), true, true, *logger_);
-      input_args.push_back(output_gradient_node_arg);
+    if (input_names.find(output_gradient_name) != input_names.end()) {
+      split_graphs_info_.user_output_grad_names.emplace_back(output_gradient_name);
+      // Only add to graph input when it's not an output of a node.
+      if (output_names.find(output_gradient_name) == output_names.end()) {
+        split_graphs_info_.backward_output_grad_names.emplace_back(output_gradient_name);
+        NodeArg* output_gradient_node_arg = gradient_graph.GetNodeArg(output_gradient_name);
+        output_gradient_node_arg->UpdateTypeAndShape(*gradient_graph.GetNodeArg(output_name), true, true, *logger_);
+        input_args.emplace_back(output_gradient_node_arg);
+      }
     }
   }
 
   gradient_graph.SetInputs(input_args);
 
-  // Add weight gradients to graph outputs.
-  for (const auto& weight_name : config.weight_names_to_train) {
-    std::string weight_gradient_name = weight_name + "_grad";
-    if (output_names.find(weight_gradient_name) != output_names.end()) {
-      output_args.push_back(gradient_graph.GetNodeArg(weight_gradient_name));
+  std::vector<const NodeArg*> output_args;
+  for (auto& output_name : split_graphs_info_.user_output_names) {
+    output_args.emplace_back(gradient_graph.GetNodeArg(output_name));
+  }
+
+  // Add initializer gradients to graph outputs.
+  for (const auto& initializer_name : split_graphs_info_.initializer_names_to_train) {
+    std::string initializer_gradient_name = initializer_name + "_grad";
+    if (output_names.find(initializer_gradient_name) != output_names.end()) {
+      output_args.emplace_back(gradient_graph.GetNodeArg(initializer_gradient_name));
+    }
   }
 
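The loop above decides, per user output, whether its gradient becomes a backward entry point. A rough Python restatement of that decision, with illustrative names (not from the commit):

# For each user output y, y_grad is recorded as a user output gradient if some
# backward node consumes it, and becomes a true graph input only if no node produces it.
user_output_names = ['loss']                 # hypothetical
input_names = {'loss_grad', 'activation_1'}  # names consumed by backward nodes
output_names = set()                         # names produced by backward nodes
user_output_grad_names = [y + '_grad' for y in user_output_names
                          if y + '_grad' in input_names]
backward_output_grad_names = [g for g in user_output_grad_names
                              if g not in output_names]
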
@@ -153,7 +160,7 @@ Status ModuleGradientGraphBuilder::BuildAndSplit(std::istream& model_istream,
   for (const auto& input_name : config.input_names_require_grad) {
     std::string input_gradient_name = input_name + "_grad";
     if (output_names.find(input_gradient_name) != output_names.end()) {
-      output_args.push_back(gradient_graph.GetNodeArg(input_gradient_name));
+      output_args.emplace_back(gradient_graph.GetNodeArg(input_gradient_name));
     }
   }
 
@@ -167,33 +174,33 @@ Status ModuleGradientGraphBuilder::BuildAndSplit(std::istream& model_istream,
   ORT_RETURN_IF_ERROR(Model::Load(gradient_model_proto, backward_model_, nullptr, *logger_));
 
   // Split the graph in the copies of gradient model.
-  ORT_RETURN_IF_ERROR(Split(config, graph_output_names));
-
-  // Serialize the models as output to frontend.
-  std::string gradient_model_str;
-  if (!model_->ToProto().SerializeToString(&gradient_model_str)) {
-    return Status(ONNXRUNTIME, FAIL, "Fail to serialize gradient model to string.");
-  }
-
-  std::string forward_model_str;
-  if (!forward_model_->ToProto().SerializeToString(&forward_model_str)) {
-    return Status(ONNXRUNTIME, FAIL, "Fail to serialize forward model to string.");
-  }
-
-  std::string backward_model_str;
-  if (!backward_model_->ToProto().SerializeToString(&backward_model_str)) {
-    return Status(ONNXRUNTIME, FAIL, "Fail to serialize backward model to string.");
-  }
-
-  models_as_string.push_back(gradient_model_str);
-  models_as_string.push_back(forward_model_str);
-  models_as_string.push_back(backward_model_str);
+  ORT_RETURN_IF_ERROR(Split());
 
   return Status::OK();
 }
 
-Status ModuleGradientGraphBuilder::Split(const ModuleGradientGraphBuilderConfiguration& config,
-                                         const std::vector<std::string>& graph_output_names) {
+std::string SerializeModel(const std::shared_ptr<onnxruntime::Model>& model, const std::string& tag) {
+  std::string model_str;
+  if (!model->ToProto().SerializeToString(&model_str)) {
+    ORT_THROW("Failed to serialize ", tag, " model to string.");
+  }
+
+  return model_str;
+}
+
+std::string ModuleGradientGraphBuilder::GetGradientModel() const {
+  return SerializeModel(model_, "gradient");
+}
+
+std::string ModuleGradientGraphBuilder::GetForwardModel() const {
+  return SerializeModel(forward_model_, "forward");
+}
+
+std::string ModuleGradientGraphBuilder::GetBackwardModel() const {
+  return SerializeModel(backward_model_, "backward");
+}
+
+Status ModuleGradientGraphBuilder::Split() {
   // Get forward model, also collect some information for backward model generation.
   Graph& forward_graph = forward_model_->MainGraph();
   GraphViewer forward_graph_viewer(forward_graph);
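With serialization factored into SerializeModel, each model is now produced on demand through a getter instead of being returned eagerly from BuildAndSplit. From Python, the flow this enables looks like the commit's own test script later in this diff (the model path here is a placeholder):

import onnx
from onnxruntime.capi import _pybind_state as C

builder = C.ModuleGradientGraphBuilder()
config = C.ModuleGradientGraphBuilderConfiguration()
builder.build_and_split(onnx.load('model.onnx').SerializeToString(), config)
forward_model = onnx.load_model_from_string(builder.get_forward_model())
backward_model = onnx.load_model_from_string(builder.get_backward_model())
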
@@ -207,7 +214,7 @@ Status ModuleGradientGraphBuilder::Split(const ModuleGradientGraphBuilderConfigu
     auto& node = *forward_graph.GetNode(node_index);
     // Currently we are using node description to distinguish the forward and backward nodes.
     if (node.Description() == "Backward pass") {
-      forward_nodes_to_remove.push_back(&node);
+      forward_nodes_to_remove.emplace_back(&node);
       GetInputAndOutputNames(node, backward_input_names, backward_output_names);
     } else {
       GetInputAndOutputNames(node, forward_input_names, forward_output_names);
@@ -224,39 +231,41 @@ Status ModuleGradientGraphBuilder::Split(const ModuleGradientGraphBuilderConfigu
   RemoveNodes(forward_graph, forward_nodes_to_remove);
   FilterInitializers(forward_graph, forward_input_names);
 
-  const std::vector<const NodeArg*>& forward_graph_inputs = forward_graph.GetInputsIncludingInitializers();
+  // All user inputs should also be part of the forward graph inputs.
   std::vector<const NodeArg*> forward_input_args;
-  for (const NodeArg* node_arg : forward_graph_inputs) {
-    if (forward_input_names.find(node_arg->Name()) != forward_input_names.end()) {
-      forward_input_args.push_back(node_arg);
-    }
+  for (const auto& input_name : split_graphs_info_.user_input_names) {
+    forward_input_args.emplace_back(forward_graph.GetNodeArg(input_name));
   }
 
-  // Add weights to forward graph inputs.
-  for (const auto& weight_name : config.weight_names_to_train) {
-    forward_input_args.push_back(forward_graph.GetNodeArg(weight_name));
+  // Add initializers to forward graph inputs.
+  for (const auto& initializer_name : split_graphs_info_.initializer_names_to_train) {
+    forward_input_args.emplace_back(forward_graph.GetNodeArg(initializer_name));
   }
 
   forward_graph.SetInputs(forward_input_args);
 
+  // All user outputs should also be part of the forward graph outputs.
   std::vector<const NodeArg*> forward_output_args;
-  for (const auto& output_name : graph_output_names) {
-    forward_output_args.push_back(forward_graph.GetNodeArg(output_name));
+  for (const auto& output_name : split_graphs_info_.user_output_names) {
+    forward_output_args.emplace_back(forward_graph.GetNodeArg(output_name));
   }
 
   // Add intermediate args to forward graph outputs.
   for (const auto& intermediate_arg_name : intermediate_arg_names) {
-    // Ignore those duplicates.
-    if (config.output_names.find(intermediate_arg_name) == config.output_names.end()) {
-      forward_output_args.push_back(forward_graph.GetNodeArg(intermediate_arg_name));
+    // Ignore the user outputs.
+    if (std::find(split_graphs_info_.user_output_names.begin(), split_graphs_info_.user_output_names.end(), intermediate_arg_name)
+        == split_graphs_info_.user_output_names.end()) {
+      split_graphs_info_.intermediate_tensor_names.emplace_back(intermediate_arg_name);
+      forward_output_args.emplace_back(forward_graph.GetNodeArg(intermediate_arg_name));
     }
   }
 
   forward_graph.SetOutputs(forward_output_args);
 
-  // Resolve the forward graph, keep the weight initializers for now.
+  // Resolve the forward graph, keep the trainable initializers for now.
   Graph::ResolveOptions options;
-  options.initializer_names_to_preserve = &config.weight_names_to_train;
+  std::unordered_set<std::string> initializer_names_to_train_set(split_graphs_info_.initializer_names_to_train.begin(), split_graphs_info_.initializer_names_to_train.end());
+  options.initializer_names_to_preserve = &initializer_names_to_train_set;
  forward_graph.Resolve(options);
 
   // Get backward graph.
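The forward-output logic above is what populates intermediate_tensor_names: tensors the backward pass needs but the user never asked for. A condensed Python restatement, assuming (as the earlier node scan in Split suggests) that intermediate_arg_names is the intersection of what forward nodes produce and backward nodes consume; names are illustrative:

forward_produced = {'logits', 'activation_1'}         # outputs of forward nodes
backward_consumed = {'activation_1', 'logits_grad'}   # inputs of backward nodes
user_output_names = ['logits']
intermediate_arg_names = forward_produced & backward_consumed
intermediate_tensor_names = [t for t in intermediate_arg_names
                             if t not in user_output_names]
# -> ['activation_1'], appended to the forward graph outputs.
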
@@ -267,44 +276,46 @@ Status ModuleGradientGraphBuilder::Split(const ModuleGradientGraphBuilderConfigu
   for (auto node_index : backward_node_topology_list) {
     auto& node = *backward_graph.GetNode(node_index);
     if (node.Description() != "Backward pass") {
-      backward_nodes_to_remove.push_back(&node);
+      backward_nodes_to_remove.emplace_back(&node);
     }
   }
 
   RemoveNodes(backward_graph, backward_nodes_to_remove);
 
-  const std::vector<const NodeArg*>& backward_graph_inputs = backward_graph.GetInputsIncludingInitializers();
   std::vector<const NodeArg*> backward_input_args;
-  for (auto& node_arg : backward_graph_inputs) {
-    // Only take those in the backward inputs.
-    if (backward_input_names.find(node_arg->Name()) != backward_input_names.end()) {
-      backward_input_args.push_back(node_arg);
+  for (const auto& input_name : split_graphs_info_.user_input_names) {
+    // Only take those in the backward inputs.
+    if (backward_input_names.find(input_name) != backward_input_names.end()) {
+      split_graphs_info_.backward_user_input_names.emplace_back(input_name);
+      backward_input_args.emplace_back(backward_graph.GetNodeArg(input_name));
     }
   }
 
-  // Add weight args to backward graph inputs if any node uses them.
-  for (const auto& weight_name : config.weight_names_to_train) {
-    // Weights will be inputs for backward graph.
-    if (backward_input_names.find(weight_name) != backward_input_names.end()) {
-      backward_input_args.push_back(backward_graph.GetNodeArg(weight_name));
-      backward_graph.RemoveInitializedTensor(weight_name);
+  // Add initializer args to backward graph inputs if any node uses them.
+  for (const auto& initializer_name : split_graphs_info_.initializer_names_to_train) {
+    // Some initializers will be inputs for backward graph.
+    if (backward_input_names.find(initializer_name) != backward_input_names.end()) {
+      split_graphs_info_.backward_intializer_names_as_input.emplace_back(initializer_name);
+      backward_input_args.emplace_back(backward_graph.GetNodeArg(initializer_name));
+      backward_graph.RemoveInitializedTensor(initializer_name);
     }
   }
 
   // Add intermediate args to backward graph inputs.
-  for (const auto& intermediate_arg_name : intermediate_arg_names) {
+  for (const auto& intermediate_arg_name : split_graphs_info_.intermediate_tensor_names) {
     NodeArg* intermediate_node_arg = backward_graph.GetNodeArg(intermediate_arg_name);
     intermediate_node_arg->UpdateTypeAndShape(*forward_graph.GetNodeArg(intermediate_arg_name), true, true, *logger_);
-    backward_input_args.push_back(intermediate_node_arg);
+    backward_input_args.emplace_back(intermediate_node_arg);
   }
 
   backward_graph.SetInputs(backward_input_args);
 
   // Exclude user outputs from the backward graph.
   const std::vector<const NodeArg*>& backward_graph_outputs = backward_graph.GetOutputs();
   std::vector<const NodeArg*> backward_output_args;
   for (auto& node_arg : backward_graph_outputs) {
     if (backward_output_names.find(node_arg->Name()) != backward_output_names.end()) {
-      backward_output_args.push_back(node_arg);
+      backward_output_args.emplace_back(node_arg);
     }
   }
 
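Mirroring the forward side, the backward graph promotes to explicit inputs whatever it actually consumes: a subset of the user inputs, the trainable initializers it touches (which are simultaneously dropped from its initializer list), and every intermediate tensor. A small sketch of the initializer case, with hypothetical names:

backward_input_names = {'bert.embeddings.weight', 'activation_1'}  # hypothetical
initializer_names_to_train = ['bert.embeddings.weight', 'cls.bias']
backward_intializer_names_as_input = [n for n in initializer_names_to_train
                                      if n in backward_input_names]
# -> ['bert.embeddings.weight']; it stops being a stored initializer, so the
#    runtime must feed its current value when executing the backward graph.
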
@@ -13,34 +13,54 @@ namespace training {
  * The training configuration options.
  */
 struct ModuleGradientGraphBuilderConfiguration {
-  // The names of the weights to train.
-  std::unordered_set<std::string> weight_names_to_train{};
-  // The names of inputs that require gradient.
-  std::unordered_set<std::string> input_names_require_grad{};
-  // The names of module outputs.
-  std::unordered_set<std::string> output_names{};
+  // The names of the initializers to train.
+  std::vector<std::string> initializer_names_to_train{};
+  // The names of inputs that require gradient.
+  std::vector<std::string> input_names_require_grad{};
 
   // Gradient graph configuration.
   bool use_invertible_layernorm_grad = false;
   bool set_gradients_as_graph_outputs = false;
 
   // TODO: add GraphTransformerConfiguration
   // TODO: add mixed precision config
   // TODO: do we need to support graph with loss?
 };
 
+/**
+ * The information of the split graphs for the frontend.
+ */
+struct SplitGraphsInfo {
+  std::vector<std::string> user_input_names{};
+  std::vector<std::string> initializer_names_to_train{};
+  std::vector<std::string> user_output_names{};
+  std::vector<std::string> backward_user_input_names{};
+  std::vector<std::string> backward_intializer_names_as_input{};
+  std::vector<std::string> intermediate_tensor_names{};
+  std::vector<std::string> user_output_grad_names{};
+  std::vector<std::string> backward_output_grad_names{};
+};
+
 class ModuleGradientGraphBuilder {
  public:
   Status BuildAndSplit(std::istream& model_istream,
-                       const ModuleGradientGraphBuilderConfiguration& config,
-                       std::vector<std::string>& models_as_string);
+                       const ModuleGradientGraphBuilderConfiguration& config);
+
+  std::string GetGradientModel() const;
+  std::string GetForwardModel() const;
+  std::string GetBackwardModel() const;
+  SplitGraphsInfo GetSplitGraphsInfo() const {
+    return split_graphs_info_;
+  }
 
  private:
-  Status Split(const ModuleGradientGraphBuilderConfiguration& config,
-               const std::vector<std::string>& graph_output_names);
+  Status Split();
 
   std::shared_ptr<onnxruntime::Model> model_;
   std::shared_ptr<onnxruntime::Model> forward_model_;
   std::shared_ptr<onnxruntime::Model> backward_model_;
+  SplitGraphsInfo split_graphs_info_;
 
   const logging::Logger* logger_;
 };
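SplitGraphsInfo is the contract the frontend programs against. A sketch of how its fields could drive one training step; the helper itself is an assumption for illustration (wiring only, session execution omitted), and the "_grad" suffix for fetched gradients follows the builder code above:

def plan_training_step(info):
    # Forward: feed user inputs plus trainable initializers,
    # fetch user outputs plus the cached intermediates.
    forward_feeds = list(info.user_input_names) + list(info.initializer_names_to_train)
    forward_fetches = list(info.user_output_names) + list(info.intermediate_tensor_names)
    # Backward: feed only the user inputs/initializers it consumes,
    # the intermediates, and the incoming output gradients.
    backward_feeds = (list(info.backward_user_input_names)
                      + list(info.backward_intializer_names_as_input)
                      + list(info.intermediate_tensor_names)
                      + list(info.backward_output_grad_names))
    backward_fetches = [name + '_grad' for name in info.initializer_names_to_train]
    return forward_feeds, forward_fetches, backward_feeds, backward_fetches
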
@@ -357,11 +357,22 @@ void addObjectMethodsForTraining(py::module& m) {
   py::class_<ModuleGradientGraphBuilderConfiguration> module_gradient_graph_builder_config(
       m, "ModuleGradientGraphBuilderConfiguration", R"pbdoc(Configuration information for module gradient graph builder.)pbdoc");
   module_gradient_graph_builder_config.def(py::init())
-      .def_readwrite("weight_names_to_train", &ModuleGradientGraphBuilderConfiguration::weight_names_to_train)
+      .def_readwrite("initializer_names_to_train", &ModuleGradientGraphBuilderConfiguration::initializer_names_to_train)
       .def_readwrite("input_names_require_grad", &ModuleGradientGraphBuilderConfiguration::input_names_require_grad)
-      .def_readwrite("output_names", &ModuleGradientGraphBuilderConfiguration::output_names)
       .def_readwrite("use_invertible_layernorm_grad", &ModuleGradientGraphBuilderConfiguration::use_invertible_layernorm_grad)
       .def_readwrite("set_gradients_as_graph_outputs", &ModuleGradientGraphBuilderConfiguration::set_gradients_as_graph_outputs);
 
+  py::class_<SplitGraphsInfo> split_graphs_info(
+      m, "SplitGraphsInfo", R"pbdoc(The information of split graphs for frontend.)pbdoc");
+  split_graphs_info.def(py::init())
+      .def_readwrite("user_input_names", &SplitGraphsInfo::user_input_names)
+      .def_readwrite("initializer_names_to_train", &SplitGraphsInfo::initializer_names_to_train)
+      .def_readwrite("user_output_names", &SplitGraphsInfo::user_output_names)
+      .def_readwrite("backward_user_input_names", &SplitGraphsInfo::backward_user_input_names)
+      .def_readwrite("backward_intializer_names_as_input", &SplitGraphsInfo::backward_intializer_names_as_input)
+      .def_readwrite("intermediate_tensor_names", &SplitGraphsInfo::intermediate_tensor_names)
+      .def_readwrite("user_output_grad_names", &SplitGraphsInfo::user_output_grad_names)
+      .def_readwrite("backward_output_grad_names", &SplitGraphsInfo::backward_output_grad_names);
+
   py::class_<ModuleGradientGraphBuilder> module_gradient_graph_builder(m, "ModuleGradientGraphBuilder");
   module_gradient_graph_builder
@@ -372,14 +383,19 @@ void addObjectMethodsForTraining(py::module& m) {
          const py::bytes& serialized_model,
          const ModuleGradientGraphBuilderConfiguration& config) {
         std::istringstream buffer(serialized_model);
-        std::vector<std::string> models_as_string;
-        ORT_THROW_IF_ERROR(module_gradient_graph_builder->BuildAndSplit(buffer, config, models_as_string));
-        std::vector<py::bytes> models_as_bytes;
-        for (size_t i = 0; i < 3; i++) {
-          models_as_bytes.push_back(py::bytes(models_as_string[i]));
-        }
-
-        return models_as_bytes;
+        ORT_THROW_IF_ERROR(module_gradient_graph_builder->BuildAndSplit(buffer, config));
       })
+      .def("get_gradient_model", [](ModuleGradientGraphBuilder* module_gradient_graph_builder) {
+        return py::bytes(module_gradient_graph_builder->GetGradientModel());
+      })
+      .def("get_forward_model", [](ModuleGradientGraphBuilder* module_gradient_graph_builder) {
+        return py::bytes(module_gradient_graph_builder->GetForwardModel());
+      })
+      .def("get_backward_model", [](ModuleGradientGraphBuilder* module_gradient_graph_builder) {
+        return py::bytes(module_gradient_graph_builder->GetBackwardModel());
+      })
+      .def("get_split_graphs_info", [](ModuleGradientGraphBuilder* module_gradient_graph_builder) {
+        return module_gradient_graph_builder->GetSplitGraphsInfo();
+      });
 }
 }  // namespace python
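The net effect on the Python surface: build_and_split no longer returns the three serialized models, so callers fetch them, and the split info, through the new accessors. Schematically, with model_bytes, config, and builder assumed in scope:

# Before this commit: one call returned a list of three serialized models.
#   gradient, forward, backward = builder.build_and_split(model_bytes, config)
# After this commit: build once, then pull each artifact on demand.
builder.build_and_split(model_bytes, config)
gradient_bytes = builder.get_gradient_model()
forward_bytes = builder.get_forward_model()
backward_bytes = builder.get_backward_model()
split_graphs_info = builder.get_split_graphs_info()
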
@@ -3,122 +3,23 @@ import copy
-from onnx import shape_inference
 from onnxruntime.capi import _pybind_state as C
 
-def add_input_from_initializer(model, initializer, docstring=None):
-    new_input = onnx.helper.make_tensor_value_info(initializer.name, initializer.data_type, initializer.dims, docstring)
-    model.graph.input.append(new_input)
-
-def add_input(model, name, data_type=None, dims=None, docstring=None):
-    new_input = onnx.helper.make_tensor_value_info(name, data_type, dims, docstring)
-    model.graph.input.append(new_input)
-
-def add_output(model, name, data_type=None, docstring=None):
-    new_output = model.graph.value_info.add()
-    new_output.name = name
-    if data_type:
-        new_output.type.CopyFrom(data_type)
-    if docstring:
-        new_output.doc_string = docstring
-    model.graph.output.append(new_output)
-
-def remove_nodes(onnx_model, nodes_to_remove):
-    all_nodes = []
-    for node in onnx_model.graph.node:
-        if node not in nodes_to_remove:
-            all_nodes.append(node)
-
-    onnx_model.graph.ClearField('node')
-    onnx_model.graph.node.extend(all_nodes)
-
-def split_graph(onnx_model):
-    forward_graph_outputs = set()
-    backward_graph_inputs = set()
-    backward_graph_outputs = set()
-    # Get forward graph
-    forward_model = copy.deepcopy(onnx_model)
-    nodes_to_remove_from_forward_graph = []
-    initializers = {}
-    for initializer in forward_model.graph.initializer:
-        initializers[initializer.name] = initializer
-    forward_graph_initializer_names = set()
-    for node in forward_model.graph.node:
-        if node.doc_string == 'Backward pass':
-            # nodes belong to the backward graph
-            nodes_to_remove_from_forward_graph.append(node)
-            for input in node.input:
-                backward_graph_inputs.add(input)
-            for output in node.output:
-                backward_graph_outputs.add(output)
-        else:
-            # nodes belong to the forward graph
-            for input in node.input:
-                if input in initializers:
-                    forward_graph_initializer_names.add(input)
-            for output in node.output:
-                forward_graph_outputs.add(output)
-
-    forward_model.graph.ClearField('initializer')
-    for initializer_name in forward_graph_initializer_names:
-        forward_model.graph.initializer.append(initializers[initializer_name])
-
-    # Outputs of the forward graph that are also inputs of the backward graph need to be added as graph outputs.
-    for output in forward_graph_outputs:
-        if output in backward_graph_inputs:
-            add_output(forward_model, output)
-
-    remove_nodes(forward_model, nodes_to_remove_from_forward_graph)
-
-    # Get backward graph
-    tensor_elem_types = {}
-    infered_model = shape_inference.infer_shapes(onnx_model)
-    for value_info in infered_model.graph.value_info:
-        tensor_elem_types[value_info.name] = value_info.type.tensor_type.elem_type
-
-    backward_model = copy.deepcopy(onnx_model)
-    initializers = {}
-    for initializer in backward_model.graph.initializer:
-        initializers[initializer.name] = initializer
-
-    nodes_to_remove_from_backward_graph = []
-    for node in backward_model.graph.node:
-        if node.doc_string != 'Backward pass':
-            nodes_to_remove_from_backward_graph.append(node)
-
-    # The gradient of a forward graph output will be an input of the backward graph.
-    for output in backward_model.graph.output:
-        if output.name + '_grad' in backward_graph_inputs:
-            add_input(backward_model, output.name + '_grad', output.type.tensor_type.elem_type)
-
-    backward_graph_initializer_names = set()
-    for input in backward_graph_inputs:
-        if input in forward_graph_outputs:
-            # Inputs of the backward graph that are also outputs of the forward graph need to be added as backward graph inputs.
-            add_input(backward_model, input, tensor_elem_types[input] if input in tensor_elem_types else 1)
-        elif input in forward_graph_initializer_names:
-            # Inputs coming from forward graph initializers need to be added as backward graph inputs.
-            add_input_from_initializer(backward_model, initializers[input])
-        elif input in initializers:
-            backward_graph_initializer_names.add(input)
-
-    backward_model.graph.ClearField('initializer')
-    for initializer_name in backward_graph_initializer_names:
-        backward_model.graph.initializer.append(initializers[initializer_name])
-
-    # Add gradient outputs to the backward graph outputs.
-    # TODO: need to add gradients of graph inputs to the backward graph outputs.
-    new_backward_graph_outputs = set()
-    for output in backward_graph_outputs:
-        if output.endswith('_grad') and output[:-5] in forward_graph_initializer_names:
-            new_backward_graph_outputs.add(output)
-
-    backward_model.graph.ClearField('output')
-    for output in new_backward_graph_outputs:
-        add_output(backward_model, output)
-
-    remove_nodes(backward_model, nodes_to_remove_from_backward_graph)
-
-    return forward_model, backward_model
+def print_list(name, value):
+    print(name + ':', ', '.join(value))
+
+
+def dim_str(dim):
+    if dim.HasField('dim_value'):
+        return str(dim.dim_value)
+    elif dim.HasField('dim_param'):
+        return dim.dim_param
+    return 'n/a'
+
+
+def print_type(name, type):
+    print('[' + name + ']', 'type:', type.tensor_type.elem_type, '| size:', '[' + ','.join([dim_str(d) for d in type.tensor_type.shape.dim]) + ']')
+
+
 """
 # MNIST
 original_model = onnx.load('mnist_original.onnx')
 config = C.ModuleGradientGraphBuilderConfiguration()
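For reference, the new print_list helper above emits one comma-joined line per name list; with hypothetical values:

print_list('user_input_names', ['input1', 'input2'])
# -> user_input_names: input1, input2
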
@@ -137,7 +38,6 @@ onnx.save(models[1], 'mnist_forward.onnx')
 onnx.save(models[2], 'mnist_backward.onnx')
 
-
 """
 #BERT
 original_model = onnx.load('BertForSequenceClassification_full_training.onnx')
 config = C.ModuleGradientGraphBuilderConfiguration()
@@ -154,27 +54,67 @@ models = [onnx.load_model_from_string(model_as_string) for model_as_string in C.
 onnx.save(models[0], 'bert_gradient_graph.onnx')
 onnx.save(models[1], 'bert_forward.onnx')
 onnx.save(models[2], 'bert_backward.onnx')
 
 """
 
 #BERT with loss
 original_model = onnx.load('bert-tiny-loss.onnx')
 config = C.ModuleGradientGraphBuilderConfiguration()
-weight_names_to_train = set()
+initializer_names_to_train = []
 for initializer in original_model.graph.initializer:
     if initializer.name.startswith('bert.') or initializer.name.startswith('cls.'):
-        weight_names_to_train.add(initializer.name)
-config.weight_names_to_train = weight_names_to_train
-input_names_require_grad = set()
-input_names_require_grad.add('input3')
+        initializer_names_to_train.append(initializer.name)
+config.initializer_names_to_train = initializer_names_to_train
+input_names_require_grad = []
+input_names_require_grad.append('input3')
 config.input_names_require_grad = input_names_require_grad
-output_names = set()
-#output_names.add('total_loss')
-for output in original_model.graph.output:
-    output_names.add(output.name)
-config.output_names = output_names
 
-models = [onnx.load_model_from_string(model_as_string) for model_as_string in C.ModuleGradientGraphBuilder().build_and_split(original_model.SerializeToString(), config)]
-onnx.save(models[0], 'bert_gradient_graph.onnx')
-onnx.save(models[1], 'bert_forward.onnx')
-onnx.save(models[2], 'bert_backward.onnx')
-"""
+module_gradient_graph_builder = C.ModuleGradientGraphBuilder()
+module_gradient_graph_builder.build_and_split(original_model.SerializeToString(), config)
+
+forward_model = onnx.load_model_from_string(module_gradient_graph_builder.get_forward_model())
+backward_model = onnx.load_model_from_string(module_gradient_graph_builder.get_backward_model())
+onnx.save(onnx.load_model_from_string(module_gradient_graph_builder.get_gradient_model()), 'bert_gradient_graph.onnx')
+onnx.save(forward_model, 'bert_forward.onnx')
+onnx.save(backward_model, 'bert_backward.onnx')
+
+split_graphs_info = module_gradient_graph_builder.get_split_graphs_info()
+print_list('user_input_names', split_graphs_info.user_input_names)
+print_list('initializer_names_to_train', split_graphs_info.initializer_names_to_train)
+print_list('user_output_names', split_graphs_info.user_output_names)
+print_list('backward_user_input_names', split_graphs_info.backward_user_input_names)
+print_list('backward_intializer_names_as_input', split_graphs_info.backward_intializer_names_as_input)
+print_list('intermediate_tensor_names', split_graphs_info.intermediate_tensor_names)
+print_list('user_output_grad_names', split_graphs_info.user_output_grad_names)
+print_list('backward_output_grad_names', split_graphs_info.backward_output_grad_names)
+
+type_map = {}
+for name in split_graphs_info.user_input_names:
+    type_map[name] = None
+for name in split_graphs_info.initializer_names_to_train:
+    type_map[name] = None
+for name in split_graphs_info.user_output_names:
+    type_map[name] = None
+for name in split_graphs_info.backward_user_input_names:
+    type_map[name] = None
+for name in split_graphs_info.backward_intializer_names_as_input:
+    type_map[name] = None
+for name in split_graphs_info.intermediate_tensor_names:
+    type_map[name] = None
+for name in split_graphs_info.user_output_grad_names:
+    type_map[name] = None
+for name in split_graphs_info.backward_output_grad_names:
+    type_map[name] = None
+
+for input in forward_model.graph.input:
+    if input.name in type_map and type_map[input.name] is None:
+        type_map[input.name] = input.type
+
+for output in forward_model.graph.output:
+    if output.name in type_map and type_map[output.name] is None:
+        type_map[output.name] = output.type
+    output_grad_name = output.name + '_grad'
+    if output_grad_name in type_map and type_map[output_grad_name] is None:
+        type_map[output_grad_name] = output.type
+
+for key, value in type_map.items():
+    print_type(key, value)