This commit is contained in:
Vincent Wang 2020-10-21 05:53:30 +00:00 committed by Thiago Crepaldi
parent 26e6d6d004
commit c36c8e14a7
6 changed files with 149 additions and 65 deletions

View file

@ -0,0 +1,66 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/graph/model.h"
#include "core/providers/cpu/cpu_execution_provider.h"
#include "orttraining/core/framework/module_gradient_graph_builder.h"
#include "orttraining/core/framework/gradient_graph_builder.h"
#include "orttraining/core/session/training_session.h"
#include "orttraining/core/optimizer/graph_transformer_utils.h"
namespace onnxruntime {
namespace training {
std::string ModuleGradientGraphBuilder::Build(std::istream& model_istream, const ModuleGradientGraphBuilderConfiguration& config) {
const logging::Logger& logger = logging::LoggingManager::DefaultLogger(); // use default logger for now.
ONNX_NAMESPACE::ModelProto mp;
Model::Load(model_istream, &mp);
Model model(mp, nullptr, logger);
model.MainGraph().Resolve();
const TrainingSession::TrainingConfiguration::GraphTransformerConfiguration graph_transformer_config{};
GraphTransformerManager graph_transformation_mgr{2};
std::unique_ptr<CPUExecutionProvider> cpu_execution_provider =
onnxruntime::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
auto add_transformers = [&](TransformerLevel level) {
auto transformers_to_register = transformer_utils::GeneratePreTrainingTransformers(
level, config.weight_names_to_train, graph_transformer_config, *cpu_execution_provider, {});
for (auto& entry : transformers_to_register) {
graph_transformation_mgr.Register(std::move(entry), level);
}
};
for (int i = static_cast<int>(TransformerLevel::Level1); i <= static_cast<int>(TransformerLevel::MaxLevel); i++) {
TransformerLevel level = static_cast<TransformerLevel>(i);
if (TransformerLevel::MaxLevel >= level) {
add_transformers(level);
}
}
// apply transformers
Graph& graph = model.MainGraph();
for (int i = static_cast<int>(TransformerLevel::Level1); i <= static_cast<int>(TransformerLevel::MaxLevel); i++) {
graph_transformation_mgr.ApplyTransformers(graph, static_cast<TransformerLevel>(i), logger);
}
// TODO: mixed precision transformer.
GradientGraphConfiguration gradient_graph_config{};
gradient_graph_config.use_invertible_layernorm_grad = config.use_invertible_layernorm_grad;
gradient_graph_config.set_gradients_as_graph_outputs = config.set_gradients_as_graph_outputs;
GradientGraphBuilder grad_graph_builder(&model.MainGraph(),
config.output_names,
config.weight_names_to_train,
"", // not support loss name for now.
gradient_graph_config,
logger);
grad_graph_builder.Build();
std::string str;
model.ToProto().SerializeToString(&str);
return str;
}
} // namespace training
} // namespace onnxruntime

View file

@ -0,0 +1,36 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <string>
#include <unordered_set>
namespace onnxruntime {
namespace training {
/**
* The training configuration options.
*/
struct ModuleGradientGraphBuilderConfiguration {
// The names of the weights to train.
std::unordered_set<std::string> weight_names_to_train{};
// The names of module outputs.
std::unordered_set<std::string> output_names{};
// Gradient graph configuration.
bool use_invertible_layernorm_grad = false;
bool set_gradients_as_graph_outputs = false;
// TODO: add GraphTransformerConfiguration
// TODO: add mixed precision config
// TODO: do we need to support graph with loss?
};
class ModuleGradientGraphBuilder {
public:
std::string Build(std::istream& model_istream, const ModuleGradientGraphBuilderConfiguration& config);
};
} // namespace training
} // namespace onnxruntime

View file

@ -1,32 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/graph/model.h"
#include "orttraining/core/framework/module_transformer.h"
#include "orttraining/core/framework/gradient_graph_builder.h"
namespace onnxruntime {
namespace training {
std::string ModuleTransformer::Transform(std::istream& model_istream,
const std::unordered_set<std::string>& weights_to_train,
const std::unordered_set<std::string>& output_names) {
ONNX_NAMESPACE::ModelProto mp;
Model::Load(model_istream, &mp);
Model model(mp, nullptr, logging::LoggingManager::DefaultLogger());
model.MainGraph().Resolve();
GradientGraphBuilder grad_graph_builder(&model.MainGraph(),
output_names,
weights_to_train,
"",
GradientGraphConfiguration(),
logging::LoggingManager::DefaultLogger());
grad_graph_builder.Build();
std::string str;
model.ToProto().SerializeToString(&str);
return str;
}
} // namespace training
} // namespace onnxruntime

View file

@ -1,19 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <string>
namespace onnxruntime {
namespace training {
class ModuleTransformer {
public:
std::string Transform(std::istream& model_istream,
const std::unordered_set<std::string>& weights_to_train,
const std::unordered_set<std::string>& output_names);
};
} // namespace training
} // namespace onnxruntime

View file

@ -11,7 +11,7 @@
#include "orttraining/core/session/training_session.h"
#include "orttraining/core/graph/optimizer_config.h"
#include "orttraining/core/framework/mpi_context.h"
#include "orttraining/core/framework/module_transformer.h"
#include "orttraining/core/framework/module_gradient_graph_builder.h"
#include "python/onnxruntime_pybind_mlvalue.h"
namespace onnxruntime {
@ -354,17 +354,24 @@ void addObjectMethodsForTraining(py::module& m) {
return static_cast<TrainingSession*>(sess->GetSessionHandle())->IsGraphOutputFp32Node(output_name);
});
py::class_<ModuleTransformer> module_transformer(m, "ModuleTransformer");
module_transformer
py::class_<ModuleGradientGraphBuilderConfiguration> module_gradient_graph_builder_config(
m, "ModuleGradientGraphBuilderConfiguration", R"pbdoc(Configuration information for module gradient graph builder.)pbdoc");
module_gradient_graph_builder_config.def(py::init())
.def_readwrite("weight_names_to_train", &ModuleGradientGraphBuilderConfiguration::weight_names_to_train)
.def_readwrite("output_names", &ModuleGradientGraphBuilderConfiguration::output_names)
.def_readwrite("use_invertible_layernorm_grad", &ModuleGradientGraphBuilderConfiguration::use_invertible_layernorm_grad)
.def_readwrite("set_gradients_as_graph_outputs", &ModuleGradientGraphBuilderConfiguration::set_gradients_as_graph_outputs);
py::class_<ModuleGradientGraphBuilder> module_gradient_graph_builder(m, "ModuleGradientGraphBuilder");
module_gradient_graph_builder
.def(py::init([]() {
return onnxruntime::make_unique<ModuleTransformer>();
return onnxruntime::make_unique<ModuleGradientGraphBuilder>();
}))
.def("transform", [](ModuleTransformer* transformer,
const py::bytes& serialized_model,
const std::unordered_set<std::string>& weights_to_train,
const std::unordered_set<std::string>& output_names) {
.def("build", [](ModuleGradientGraphBuilder* module_gradient_graph_builder,
const py::bytes& serialized_model,
const ModuleGradientGraphBuilderConfiguration& config) {
std::istringstream buffer(serialized_model);
std::string model_as_string = transformer->Transform(buffer, weights_to_train, output_names);
std::string model_as_string = module_gradient_graph_builder->Build(buffer, config);
return py::bytes(model_as_string);
});
}

View file

@ -119,15 +119,41 @@ def split_graph(onnx_model):
return forward_model, backward_model
# MNIST
"""
original_model = onnx.load('mnist_original.onnx')
weights_to_train = set()
config = C.ModuleGradientGraphBuilderConfiguration()
weight_names_to_train = set()
for initializer in original_model.graph.initializer:
weights_to_train.add(initializer.name)
weight_names_to_train.add(initializer.name)
config.weight_names_to_train = weight_names_to_train
output_names = set()
for output in original_model.graph.output:
output_names.add(output.name)
transformed_model = onnx.load_model_from_string(C.ModuleTransformer().transform(original_model.SerializeToString(), weights_to_train, output_names))
onnx.save(transformed_model, 'mnist_transformed.onnx')
forward_model, backward_model = split_graph(transformed_model)
config.output_names = output_names
gradient_graph_model = onnx.load_model_from_string(C.ModuleGradientGraphBuilder().build(original_model.SerializeToString(), config))
onnx.save(gradient_graph_model, 'minst_gradient_graph.onnx')
forward_model, backward_model = split_graph(gradient_graph_model)
onnx.save(forward_model, 'mnist_forward.onnx')
onnx.save(backward_model, 'mnist_backward.onnx')
"""
#BERT
original_model = onnx.load('bert-tiny.onnx')
config = C.ModuleGradientGraphBuilderConfiguration()
weight_names_to_train = set()
for initializer in original_model.graph.initializer:
weight_names_to_train.add(initializer.name)
config.weight_names_to_train = weight_names_to_train
output_names = set()
for output in original_model.graph.output:
output_names.add(output.name)
config.output_names = output_names
gradient_graph_model = onnx.load_model_from_string(C.ModuleGradientGraphBuilder().build(original_model.SerializeToString(), config))
onnx.save(gradient_graph_model, 'bert_gradient_graph.onnx')
forward_model, backward_model = split_graph(gradient_graph_model)
onnx.save(forward_model, 'bert_forward.onnx')
onnx.save(backward_model, 'bert_backward.onnx')