mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-01 03:45:06 +00:00
refactor
This commit is contained in:
parent
26e6d6d004
commit
c36c8e14a7
6 changed files with 149 additions and 65 deletions
|
|
@ -0,0 +1,66 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/graph/model.h"
|
||||
#include "core/providers/cpu/cpu_execution_provider.h"
|
||||
#include "orttraining/core/framework/module_gradient_graph_builder.h"
|
||||
#include "orttraining/core/framework/gradient_graph_builder.h"
|
||||
#include "orttraining/core/session/training_session.h"
|
||||
#include "orttraining/core/optimizer/graph_transformer_utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace training {
|
||||
|
||||
std::string ModuleGradientGraphBuilder::Build(std::istream& model_istream, const ModuleGradientGraphBuilderConfiguration& config) {
|
||||
const logging::Logger& logger = logging::LoggingManager::DefaultLogger(); // use default logger for now.
|
||||
ONNX_NAMESPACE::ModelProto mp;
|
||||
Model::Load(model_istream, &mp);
|
||||
Model model(mp, nullptr, logger);
|
||||
model.MainGraph().Resolve();
|
||||
|
||||
const TrainingSession::TrainingConfiguration::GraphTransformerConfiguration graph_transformer_config{};
|
||||
GraphTransformerManager graph_transformation_mgr{2};
|
||||
std::unique_ptr<CPUExecutionProvider> cpu_execution_provider =
|
||||
onnxruntime::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
|
||||
|
||||
auto add_transformers = [&](TransformerLevel level) {
|
||||
auto transformers_to_register = transformer_utils::GeneratePreTrainingTransformers(
|
||||
level, config.weight_names_to_train, graph_transformer_config, *cpu_execution_provider, {});
|
||||
for (auto& entry : transformers_to_register) {
|
||||
graph_transformation_mgr.Register(std::move(entry), level);
|
||||
}
|
||||
};
|
||||
|
||||
for (int i = static_cast<int>(TransformerLevel::Level1); i <= static_cast<int>(TransformerLevel::MaxLevel); i++) {
|
||||
TransformerLevel level = static_cast<TransformerLevel>(i);
|
||||
if (TransformerLevel::MaxLevel >= level) {
|
||||
add_transformers(level);
|
||||
}
|
||||
}
|
||||
|
||||
// apply transformers
|
||||
Graph& graph = model.MainGraph();
|
||||
for (int i = static_cast<int>(TransformerLevel::Level1); i <= static_cast<int>(TransformerLevel::MaxLevel); i++) {
|
||||
graph_transformation_mgr.ApplyTransformers(graph, static_cast<TransformerLevel>(i), logger);
|
||||
}
|
||||
|
||||
// TODO: mixed precision transformer.
|
||||
|
||||
GradientGraphConfiguration gradient_graph_config{};
|
||||
gradient_graph_config.use_invertible_layernorm_grad = config.use_invertible_layernorm_grad;
|
||||
gradient_graph_config.set_gradients_as_graph_outputs = config.set_gradients_as_graph_outputs;
|
||||
GradientGraphBuilder grad_graph_builder(&model.MainGraph(),
|
||||
config.output_names,
|
||||
config.weight_names_to_train,
|
||||
"", // not support loss name for now.
|
||||
gradient_graph_config,
|
||||
logger);
|
||||
grad_graph_builder.Build();
|
||||
|
||||
std::string str;
|
||||
model.ToProto().SerializeToString(&str);
|
||||
return str;
|
||||
}
|
||||
|
||||
} // namespace training
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -0,0 +1,36 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace training {
|
||||
|
||||
/**
 * Options consumed by ModuleGradientGraphBuilder::Build.
 */
struct ModuleGradientGraphBuilderConfiguration {
  // Names of the weights to train.
  std::unordered_set<std::string> weight_names_to_train;
  // Names of the module outputs.
  std::unordered_set<std::string> output_names;

  // Gradient graph options; both are off by default.
  bool use_invertible_layernorm_grad{false};
  bool set_gradients_as_graph_outputs{false};

  // TODO: add GraphTransformerConfiguration
  // TODO: add mixed precision config
  // TODO: do we need to support graph with loss?
};
|
||||
|
||||
class ModuleGradientGraphBuilder {
|
||||
public:
|
||||
std::string Build(std::istream& model_istream, const ModuleGradientGraphBuilderConfiguration& config);
|
||||
};
|
||||
|
||||
} // namespace training
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -1,32 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/graph/model.h"
|
||||
#include "orttraining/core/framework/module_transformer.h"
|
||||
#include "orttraining/core/framework/gradient_graph_builder.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace training {
|
||||
|
||||
std::string ModuleTransformer::Transform(std::istream& model_istream,
|
||||
const std::unordered_set<std::string>& weights_to_train,
|
||||
const std::unordered_set<std::string>& output_names) {
|
||||
ONNX_NAMESPACE::ModelProto mp;
|
||||
Model::Load(model_istream, &mp);
|
||||
Model model(mp, nullptr, logging::LoggingManager::DefaultLogger());
|
||||
model.MainGraph().Resolve();
|
||||
|
||||
GradientGraphBuilder grad_graph_builder(&model.MainGraph(),
|
||||
output_names,
|
||||
weights_to_train,
|
||||
"",
|
||||
GradientGraphConfiguration(),
|
||||
logging::LoggingManager::DefaultLogger());
|
||||
grad_graph_builder.Build();
|
||||
std::string str;
|
||||
model.ToProto().SerializeToString(&str);
|
||||
return str;
|
||||
}
|
||||
|
||||
} // namespace training
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -1,19 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace training {
|
||||
|
||||
// Declares the entry point that transforms a serialized module model.
class ModuleTransformer {
 public:
  // Takes the model stream plus the names of the weights to train and of the module outputs;
  // returns a string (the definition lives in the corresponding .cc file).
  std::string Transform(
      std::istream& model_istream,
      const std::unordered_set<std::string>& weights_to_train,
      const std::unordered_set<std::string>& output_names);
};
|
||||
|
||||
} // namespace training
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -11,7 +11,7 @@
|
|||
#include "orttraining/core/session/training_session.h"
|
||||
#include "orttraining/core/graph/optimizer_config.h"
|
||||
#include "orttraining/core/framework/mpi_context.h"
|
||||
#include "orttraining/core/framework/module_transformer.h"
|
||||
#include "orttraining/core/framework/module_gradient_graph_builder.h"
|
||||
#include "python/onnxruntime_pybind_mlvalue.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
|
@ -354,17 +354,24 @@ void addObjectMethodsForTraining(py::module& m) {
|
|||
return static_cast<TrainingSession*>(sess->GetSessionHandle())->IsGraphOutputFp32Node(output_name);
|
||||
});
|
||||
|
||||
py::class_<ModuleTransformer> module_transformer(m, "ModuleTransformer");
|
||||
module_transformer
|
||||
py::class_<ModuleGradientGraphBuilderConfiguration> module_gradient_graph_builder_config(
|
||||
m, "ModuleGradientGraphBuilderConfiguration", R"pbdoc(Configuration information for module gradient graph builder.)pbdoc");
|
||||
module_gradient_graph_builder_config.def(py::init())
|
||||
.def_readwrite("weight_names_to_train", &ModuleGradientGraphBuilderConfiguration::weight_names_to_train)
|
||||
.def_readwrite("output_names", &ModuleGradientGraphBuilderConfiguration::output_names)
|
||||
.def_readwrite("use_invertible_layernorm_grad", &ModuleGradientGraphBuilderConfiguration::use_invertible_layernorm_grad)
|
||||
.def_readwrite("set_gradients_as_graph_outputs", &ModuleGradientGraphBuilderConfiguration::set_gradients_as_graph_outputs);
|
||||
|
||||
py::class_<ModuleGradientGraphBuilder> module_gradient_graph_builder(m, "ModuleGradientGraphBuilder");
|
||||
module_gradient_graph_builder
|
||||
.def(py::init([]() {
|
||||
return onnxruntime::make_unique<ModuleTransformer>();
|
||||
return onnxruntime::make_unique<ModuleGradientGraphBuilder>();
|
||||
}))
|
||||
.def("transform", [](ModuleTransformer* transformer,
|
||||
const py::bytes& serialized_model,
|
||||
const std::unordered_set<std::string>& weights_to_train,
|
||||
const std::unordered_set<std::string>& output_names) {
|
||||
.def("build", [](ModuleGradientGraphBuilder* module_gradient_graph_builder,
|
||||
const py::bytes& serialized_model,
|
||||
const ModuleGradientGraphBuilderConfiguration& config) {
|
||||
std::istringstream buffer(serialized_model);
|
||||
std::string model_as_string = transformer->Transform(buffer, weights_to_train, output_names);
|
||||
std::string model_as_string = module_gradient_graph_builder->Build(buffer, config);
|
||||
return py::bytes(model_as_string);
|
||||
});
|
||||
}
|
||||
|
|
|
|||
|
|
@ -119,15 +119,41 @@ def split_graph(onnx_model):
|
|||
return forward_model, backward_model
|
||||
|
||||
|
||||
# MNIST
|
||||
"""
|
||||
original_model = onnx.load('mnist_original.onnx')
|
||||
weights_to_train = set()
|
||||
config = C.ModuleGradientGraphBuilderConfiguration()
|
||||
weight_names_to_train = set()
|
||||
for initializer in original_model.graph.initializer:
|
||||
weights_to_train.add(initializer.name)
|
||||
weight_names_to_train.add(initializer.name)
|
||||
config.weight_names_to_train = weight_names_to_train
|
||||
output_names = set()
|
||||
for output in original_model.graph.output:
|
||||
output_names.add(output.name)
|
||||
transformed_model = onnx.load_model_from_string(C.ModuleTransformer().transform(original_model.SerializeToString(), weights_to_train, output_names))
|
||||
onnx.save(transformed_model, 'mnist_transformed.onnx')
|
||||
forward_model, backward_model = split_graph(transformed_model)
|
||||
config.output_names = output_names
|
||||
|
||||
gradient_graph_model = onnx.load_model_from_string(C.ModuleGradientGraphBuilder().build(original_model.SerializeToString(), config))
|
||||
onnx.save(gradient_graph_model, 'mnist_gradient_graph.onnx')
|
||||
forward_model, backward_model = split_graph(gradient_graph_model)
|
||||
onnx.save(forward_model, 'mnist_forward.onnx')
|
||||
onnx.save(backward_model, 'mnist_backward.onnx')
|
||||
"""
|
||||
|
||||
|
||||
# BERT: build the gradient graph for bert-tiny, then save it and its forward/backward split.
original_model = onnx.load('bert-tiny.onnx')

config = C.ModuleGradientGraphBuilderConfiguration()
# Every initializer is treated as a weight to train; every graph output as a module output.
weight_names_to_train = {init.name for init in original_model.graph.initializer}
output_names = {graph_output.name for graph_output in original_model.graph.output}
config.weight_names_to_train = weight_names_to_train
config.output_names = output_names

serialized_gradient_graph = C.ModuleGradientGraphBuilder().build(
    original_model.SerializeToString(), config)
gradient_graph_model = onnx.load_model_from_string(serialized_gradient_graph)
onnx.save(gradient_graph_model, 'bert_gradient_graph.onnx')

forward_model, backward_model = split_graph(gradient_graph_model)
onnx.save(forward_model, 'bert_forward.onnx')
onnx.save(backward_model, 'bert_backward.onnx')
|
||||
Loading…
Reference in a new issue