diff --git a/orttraining/orttraining/core/framework/module_gradient_graph_builder.cc b/orttraining/orttraining/core/framework/module_gradient_graph_builder.cc new file mode 100644 index 0000000000..3e05da2390 --- /dev/null +++ b/orttraining/orttraining/core/framework/module_gradient_graph_builder.cc @@ -0,0 +1,66 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/graph/model.h" +#include "core/providers/cpu/cpu_execution_provider.h" +#include "orttraining/core/framework/module_gradient_graph_builder.h" +#include "orttraining/core/framework/gradient_graph_builder.h" +#include "orttraining/core/session/training_session.h" +#include "orttraining/core/optimizer/graph_transformer_utils.h" + +namespace onnxruntime { +namespace training { + +std::string ModuleGradientGraphBuilder::Build(std::istream& model_istream, const ModuleGradientGraphBuilderConfiguration& config) { + const logging::Logger& logger = logging::LoggingManager::DefaultLogger(); // use default logger for now. + ONNX_NAMESPACE::ModelProto mp; + Model::Load(model_istream, &mp); + Model model(mp, nullptr, logger); + model.MainGraph().Resolve(); + + const TrainingSession::TrainingConfiguration::GraphTransformerConfiguration graph_transformer_config{}; + GraphTransformerManager graph_transformation_mgr{2}; + std::unique_ptr cpu_execution_provider = + onnxruntime::make_unique(CPUExecutionProviderInfo()); + + auto add_transformers = [&](TransformerLevel level) { + auto transformers_to_register = transformer_utils::GeneratePreTrainingTransformers( + level, config.weight_names_to_train, graph_transformer_config, *cpu_execution_provider, {}); + for (auto& entry : transformers_to_register) { + graph_transformation_mgr.Register(std::move(entry), level); + } + }; + + for (int i = static_cast(TransformerLevel::Level1); i <= static_cast(TransformerLevel::MaxLevel); i++) { + TransformerLevel level = static_cast(i); + if (TransformerLevel::MaxLevel >= level) { + add_transformers(level); + } + } + + // apply transformers + Graph& graph = model.MainGraph(); + for (int i = static_cast(TransformerLevel::Level1); i <= static_cast(TransformerLevel::MaxLevel); i++) { + graph_transformation_mgr.ApplyTransformers(graph, static_cast(i), logger); + } + + // TODO: mixed precision transformer. + + GradientGraphConfiguration gradient_graph_config{}; + gradient_graph_config.use_invertible_layernorm_grad = config.use_invertible_layernorm_grad; + gradient_graph_config.set_gradients_as_graph_outputs = config.set_gradients_as_graph_outputs; + GradientGraphBuilder grad_graph_builder(&model.MainGraph(), + config.output_names, + config.weight_names_to_train, + "", // not support loss name for now. + gradient_graph_config, + logger); + grad_graph_builder.Build(); + + std::string str; + model.ToProto().SerializeToString(&str); + return str; +} + +} // namespace training +} // namespace onnxruntime diff --git a/orttraining/orttraining/core/framework/module_gradient_graph_builder.h b/orttraining/orttraining/core/framework/module_gradient_graph_builder.h new file mode 100644 index 0000000000..6ced3750b8 --- /dev/null +++ b/orttraining/orttraining/core/framework/module_gradient_graph_builder.h @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include + +namespace onnxruntime { +namespace training { + +/** + * The training configuration options. + */ +struct ModuleGradientGraphBuilderConfiguration { +// The names of the weights to train. +std::unordered_set weight_names_to_train{}; +// The names of module outputs. +std::unordered_set output_names{}; + +// Gradient graph configuration. +bool use_invertible_layernorm_grad = false; +bool set_gradients_as_graph_outputs = false; + +// TODO: add GraphTransformerConfiguration +// TODO: add mixed precision config +// TODO: do we need to support graph with loss? +}; + +class ModuleGradientGraphBuilder { + public: + std::string Build(std::istream& model_istream, const ModuleGradientGraphBuilderConfiguration& config); +}; + +} // namespace training +} // namespace onnxruntime diff --git a/orttraining/orttraining/core/framework/module_transformer.cc b/orttraining/orttraining/core/framework/module_transformer.cc deleted file mode 100644 index 2dea4a327b..0000000000 --- a/orttraining/orttraining/core/framework/module_transformer.cc +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/graph/model.h" -#include "orttraining/core/framework/module_transformer.h" -#include "orttraining/core/framework/gradient_graph_builder.h" - -namespace onnxruntime { -namespace training { - -std::string ModuleTransformer::Transform(std::istream& model_istream, - const std::unordered_set& weights_to_train, - const std::unordered_set& output_names) { - ONNX_NAMESPACE::ModelProto mp; - Model::Load(model_istream, &mp); - Model model(mp, nullptr, logging::LoggingManager::DefaultLogger()); - model.MainGraph().Resolve(); - - GradientGraphBuilder grad_graph_builder(&model.MainGraph(), - output_names, - weights_to_train, - "", - GradientGraphConfiguration(), - logging::LoggingManager::DefaultLogger()); - grad_graph_builder.Build(); - std::string str; - model.ToProto().SerializeToString(&str); - return str; -} - -} // namespace training -} // namespace onnxruntime diff --git a/orttraining/orttraining/core/framework/module_transformer.h b/orttraining/orttraining/core/framework/module_transformer.h deleted file mode 100644 index b4d8315969..0000000000 --- a/orttraining/orttraining/core/framework/module_transformer.h +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include - -namespace onnxruntime { -namespace training { - -class ModuleTransformer { - public: - std::string Transform(std::istream& model_istream, - const std::unordered_set& weights_to_train, - const std::unordered_set& output_names); -}; - -} // namespace training -} // namespace onnxruntime diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc index 9388df7ba8..33950aa1cf 100644 --- a/orttraining/orttraining/python/orttraining_pybind_state.cc +++ b/orttraining/orttraining/python/orttraining_pybind_state.cc @@ -11,7 +11,7 @@ #include "orttraining/core/session/training_session.h" #include "orttraining/core/graph/optimizer_config.h" #include "orttraining/core/framework/mpi_context.h" -#include "orttraining/core/framework/module_transformer.h" +#include "orttraining/core/framework/module_gradient_graph_builder.h" #include "python/onnxruntime_pybind_mlvalue.h" namespace onnxruntime { @@ -354,17 +354,24 @@ void addObjectMethodsForTraining(py::module& m) { return static_cast(sess->GetSessionHandle())->IsGraphOutputFp32Node(output_name); }); - py::class_ module_transformer(m, "ModuleTransformer"); - module_transformer + py::class_ module_gradient_graph_builder_config( + m, "ModuleGradientGraphBuilderConfiguration", R"pbdoc(Configuration information for module gradient graph builder.)pbdoc"); + module_gradient_graph_builder_config.def(py::init()) + .def_readwrite("weight_names_to_train", &ModuleGradientGraphBuilderConfiguration::weight_names_to_train) + .def_readwrite("output_names", &ModuleGradientGraphBuilderConfiguration::output_names) + .def_readwrite("use_invertible_layernorm_grad", &ModuleGradientGraphBuilderConfiguration::use_invertible_layernorm_grad) + .def_readwrite("set_gradients_as_graph_outputs", &ModuleGradientGraphBuilderConfiguration::set_gradients_as_graph_outputs); + + py::class_ module_gradient_graph_builder(m, "ModuleGradientGraphBuilder"); + module_gradient_graph_builder .def(py::init([]() { - return onnxruntime::make_unique(); + return onnxruntime::make_unique(); })) - .def("transform", [](ModuleTransformer* transformer, - const py::bytes& serialized_model, - const std::unordered_set& weights_to_train, - const std::unordered_set& output_names) { + .def("build", [](ModuleGradientGraphBuilder* module_gradient_graph_builder, + const py::bytes& serialized_model, + const ModuleGradientGraphBuilderConfiguration& config) { std::istringstream buffer(serialized_model); - std::string model_as_string = transformer->Transform(buffer, weights_to_train, output_names); + std::string model_as_string = module_gradient_graph_builder->Build(buffer, config); return py::bytes(model_as_string); }); } diff --git a/samples/python/mnist/graph_spliter.py b/samples/python/mnist/graph_spliter.py index 67cc625f92..f857f6d202 100644 --- a/samples/python/mnist/graph_spliter.py +++ b/samples/python/mnist/graph_spliter.py @@ -119,15 +119,41 @@ def split_graph(onnx_model): return forward_model, backward_model +# MNIST +""" original_model = onnx.load('mnist_original.onnx') -weights_to_train = set() +config = C.ModuleGradientGraphBuilderConfiguration() +weight_names_to_train = set() for initializer in original_model.graph.initializer: - weights_to_train.add(initializer.name) + weight_names_to_train.add(initializer.name) +config.weight_names_to_train = weight_names_to_train output_names = set() for output in original_model.graph.output: output_names.add(output.name) -transformed_model = onnx.load_model_from_string(C.ModuleTransformer().transform(original_model.SerializeToString(), weights_to_train, output_names)) -onnx.save(transformed_model, 'mnist_transformed.onnx') -forward_model, backward_model = split_graph(transformed_model) +config.output_names = output_names + +gradient_graph_model = onnx.load_model_from_string(C.ModuleGradientGraphBuilder().build(original_model.SerializeToString(), config)) +onnx.save(gradient_graph_model, 'minst_gradient_graph.onnx') +forward_model, backward_model = split_graph(gradient_graph_model) onnx.save(forward_model, 'mnist_forward.onnx') onnx.save(backward_model, 'mnist_backward.onnx') +""" + + +#BERT +original_model = onnx.load('bert-tiny.onnx') +config = C.ModuleGradientGraphBuilderConfiguration() +weight_names_to_train = set() +for initializer in original_model.graph.initializer: + weight_names_to_train.add(initializer.name) +config.weight_names_to_train = weight_names_to_train +output_names = set() +for output in original_model.graph.output: + output_names.add(output.name) +config.output_names = output_names + +gradient_graph_model = onnx.load_model_from_string(C.ModuleGradientGraphBuilder().build(original_model.SerializeToString(), config)) +onnx.save(gradient_graph_model, 'bert_gradient_graph.onnx') +forward_model, backward_model = split_graph(gradient_graph_model) +onnx.save(forward_model, 'bert_forward.onnx') +onnx.save(backward_model, 'bert_backward.onnx') \ No newline at end of file