From 7729bb3c8d19956f4fc26ee0729dc1d8285b9eb1 Mon Sep 17 00:00:00 2001 From: Vincent Wang Date: Mon, 7 Dec 2020 07:34:01 +0000 Subject: [PATCH] Add initial dynamic axes support --- .../module_gradient_graph_builder.cc | 103 +++++++++++------- .../framework/module_gradient_graph_builder.h | 7 +- .../python/orttraining_pybind_state.cc | 13 ++- .../orttraining/python/training/ortmodule.py | 63 ++++------- ...training_test_ortmodule_bert_classifier.py | 21 ++-- 5 files changed, 107 insertions(+), 100 deletions(-) diff --git a/orttraining/orttraining/core/framework/module_gradient_graph_builder.cc b/orttraining/orttraining/core/framework/module_gradient_graph_builder.cc index af5a4fb215..e9373e3554 100644 --- a/orttraining/orttraining/core/framework/module_gradient_graph_builder.cc +++ b/orttraining/orttraining/core/framework/module_gradient_graph_builder.cc @@ -43,13 +43,16 @@ void FilterInitializers(Graph& graph, const std::unordered_set& inp } } -Status ModuleGradientGraphBuilder::BuildAndSplit(std::istream& model_istream, - const ModuleGradientGraphBuilderConfiguration& config) { - logger_ = &logging::LoggingManager::DefaultLogger(); // use default logger for now. +Status ModuleGradientGraphBuilder::Initialize(std::istream& model_istream, + const ModuleGradientGraphBuilderConfiguration& config) { + // We need to apply the pre-training transformers before the gradient graph builder so we can build + // an optimized gradient graph. The constant folding transformer depends on concrete shapes, without + // constant folding with concrete shapes, shapes of some intermediate tensors will fail to infer. + // This means we need to "apply transformers -> build gradient graph -> split" each time we have different + // concrete input shapes. So this init func is just to save the original graph and config. ONNX_NAMESPACE::ModelProto model_proto; ORT_RETURN_IF_ERROR(Model::Load(model_istream, &model_proto)); ORT_RETURN_IF_ERROR(Model::Load(model_proto, model_, nullptr, *logger_)); - ORT_RETURN_IF_ERROR(model_->MainGraph().Resolve()); // Handle original model inputs, outputs and trainable initializers. const std::vector& graph_inputs = model_->MainGraph().GetInputsIncludingInitializers(); @@ -65,6 +68,35 @@ Status ModuleGradientGraphBuilder::BuildAndSplit(std::istream& model_istream, split_graphs_info_.initializer_names_to_train.assign(config.initializer_names_to_train.begin(), config.initializer_names_to_train.end()); + config_ = config; + return Status::OK(); +} + +Status ModuleGradientGraphBuilder::BuildAndSplit(const std::vector>& input_shapes) { + // Make a copy of the original model. + auto model_proto = model_->ToProto(); + std::shared_ptr model_copied; + ORT_RETURN_IF_ERROR(Model::Load(model_proto, model_copied, nullptr, *logger_)); + Graph& graph = model_copied->MainGraph(); + + // Replace the input shapes. + std::vector input_args; + size_t input_index = 0; + for (const auto& input_name : split_graphs_info_.user_input_names) { + NodeArg* input_node_arg = graph.GetNodeArg(input_name); + ONNX_NAMESPACE::TensorShapeProto new_shape; + for (size_t i = 0; i < input_shapes[input_index].size(); i++) { + new_shape.add_dim()->set_dim_value(input_shapes[input_index][i]); + } + + input_node_arg->SetShape(new_shape); + input_args.emplace_back(input_node_arg); + input_index++; + } + + graph.SetInputs(input_args); + ORT_RETURN_IF_ERROR(graph.Resolve()); + // Register and apply transformers for pre-training. const TrainingSession::TrainingConfiguration::GraphTransformerConfiguration graph_transformer_config{}; GraphTransformerManager graph_transformation_mgr{2}; @@ -72,8 +104,8 @@ Status ModuleGradientGraphBuilder::BuildAndSplit(std::istream& model_istream, onnxruntime::make_unique(CPUExecutionProviderInfo()); std::unordered_set x_node_arg_names; - std::set_union(config.initializer_names_to_train.begin(), config.initializer_names_to_train.end(), - config.input_names_require_grad.begin(), config.input_names_require_grad.end(), + std::set_union(config_.initializer_names_to_train.begin(), config_.initializer_names_to_train.end(), + config_.input_names_require_grad.begin(), config_.input_names_require_grad.end(), std::inserter(x_node_arg_names, x_node_arg_names.begin())); auto add_transformers = [&](TransformerLevel level) { std::unordered_map updated_weight_names{}; @@ -91,41 +123,39 @@ Status ModuleGradientGraphBuilder::BuildAndSplit(std::istream& model_istream, } } - Graph& graph = model_->MainGraph(); for (int i = static_cast(TransformerLevel::Level1); i <= static_cast(TransformerLevel::MaxLevel); i++) { ORT_RETURN_IF_ERROR(graph_transformation_mgr.ApplyTransformers(graph, static_cast(i), *logger_)); } - // TODO: mixed precision transformer. - // Build gradient graph. GradientGraphConfiguration gradient_graph_config{}; - gradient_graph_config.use_invertible_layernorm_grad = config.use_invertible_layernorm_grad; - gradient_graph_config.set_gradients_as_graph_outputs = config.set_gradients_as_graph_outputs; + gradient_graph_config.use_invertible_layernorm_grad = config_.use_invertible_layernorm_grad; + gradient_graph_config.set_gradients_as_graph_outputs = config_.set_gradients_as_graph_outputs; std::unordered_set y_node_arg_names(split_graphs_info_.user_output_names.begin(), split_graphs_info_.user_output_names.end()); - GradientGraphBuilder grad_graph_builder(&model_->MainGraph(), y_node_arg_names, x_node_arg_names, - "", // not support loss name for now. + GradientGraphBuilder grad_graph_builder(&graph, y_node_arg_names, x_node_arg_names, + "", gradient_graph_config, *logger_); ORT_RETURN_IF_ERROR(grad_graph_builder.Build()); // Fix inputs/outputs related to gradients. - Graph& gradient_graph = model_->MainGraph(); - GraphViewer gradient_graph_viewer(gradient_graph); - const auto& node_topology_list = gradient_graph_viewer.GetNodesInTopologicalOrder(); + GraphViewer graph_viewer(graph); + const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); std::unordered_set input_names; std::unordered_set output_names; for (auto node_index : node_topology_list) { - auto& node = *gradient_graph.GetNode(node_index); + auto& node = *graph.GetNode(node_index); GetInputAndOutputNames(node, input_names, output_names); } - std::vector input_args; + input_args.clear(); for (auto& input_name : split_graphs_info_.user_input_names) { - input_args.emplace_back(gradient_graph.GetNodeArg(input_name)); + input_args.emplace_back(graph.GetNodeArg(input_name)); } // Add the entry points of gradients (normally loss_gard) to the graph inputs. Using the order of graph outputs. + split_graphs_info_.user_output_grad_names.clear(); + split_graphs_info_.backward_output_grad_names.clear(); for (const auto& output_name : split_graphs_info_.user_output_names) { std::string output_gradient_name = output_name + "_grad"; if (input_names.find(output_gradient_name) != input_names.end()) { @@ -133,48 +163,48 @@ Status ModuleGradientGraphBuilder::BuildAndSplit(std::istream& model_istream, // Only add to graph input when it's not an output of a node. if (output_names.find(output_gradient_name) == output_names.end()) { split_graphs_info_.backward_output_grad_names.emplace_back(output_gradient_name); - NodeArg* output_gradient_node_arg = gradient_graph.GetNodeArg(output_gradient_name); - output_gradient_node_arg->UpdateTypeAndShape(*gradient_graph.GetNodeArg(output_name), true, true, *logger_); + NodeArg* output_gradient_node_arg = graph.GetNodeArg(output_gradient_name); + output_gradient_node_arg->UpdateTypeAndShape(*graph.GetNodeArg(output_name), true, true, *logger_); input_args.emplace_back(output_gradient_node_arg); } } } - gradient_graph.SetInputs(input_args); + graph.SetInputs(input_args); std::vector output_args; for (auto& output_name : split_graphs_info_.user_output_names) { - output_args.emplace_back(gradient_graph.GetNodeArg(output_name)); + output_args.emplace_back(graph.GetNodeArg(output_name)); } // Add initializer gradients to graph outputs. + split_graphs_info_.initializer_grad_names_to_train.clear(); for (const auto& initializer_name : split_graphs_info_.initializer_names_to_train) { std::string initializer_gradient_name = initializer_name + "_grad"; if (output_names.find(initializer_gradient_name) != output_names.end()) { split_graphs_info_.initializer_grad_names_to_train.emplace_back(initializer_gradient_name); - output_args.emplace_back(gradient_graph.GetNodeArg(initializer_gradient_name)); + output_args.emplace_back(graph.GetNodeArg(initializer_gradient_name)); } } // Add input gradients to graph outputs if it's required. - for (const auto& input_name : config.input_names_require_grad) { + for (const auto& input_name : config_.input_names_require_grad) { std::string input_gradient_name = input_name + "_grad"; if (output_names.find(input_gradient_name) != output_names.end()) { - output_args.emplace_back(gradient_graph.GetNodeArg(input_gradient_name)); + output_args.emplace_back(graph.GetNodeArg(input_gradient_name)); } } - gradient_graph.SetOutputs(output_args); + graph.SetOutputs(output_args); + graph.Resolve(); - gradient_graph.Resolve(); - - // Run the transformers again mainly for backward part. + // Run the transformers again mainly for backward part, e.g., constant fold from those Shape nodes in backward graph. for (int i = static_cast(TransformerLevel::Level1); i <= static_cast(TransformerLevel::MaxLevel); i++) { - ORT_RETURN_IF_ERROR(graph_transformation_mgr.ApplyTransformers(gradient_graph, static_cast(i), *logger_)); + ORT_RETURN_IF_ERROR(graph_transformation_mgr.ApplyTransformers(graph, static_cast(i), *logger_)); } // Create two copies of gradient model for forward and backward models respectively. - auto gradient_model_proto = model_->ToProto(); + auto gradient_model_proto = model_copied->ToProto(); ORT_RETURN_IF_ERROR(Model::Load(gradient_model_proto, forward_model_, nullptr, *logger_)); ORT_RETURN_IF_ERROR(Model::Load(gradient_model_proto, backward_model_, nullptr, *logger_)); @@ -193,8 +223,6 @@ std::string SerializeModel(const std::shared_ptr& model, con return model_str; } -std::string ModuleGradientGraphBuilder::GetGradientModel() const { return SerializeModel(model_, "gradient"); } - std::string ModuleGradientGraphBuilder::GetForwardModel() const { return SerializeModel(forward_model_, "forward"); } std::string ModuleGradientGraphBuilder::GetBackwardModel() const { return SerializeModel(backward_model_, "backward"); } @@ -251,6 +279,7 @@ Status ModuleGradientGraphBuilder::Split() { } // Add intermediate args to forward graph outputs. + split_graphs_info_.intermediate_tensor_names.clear(); for (const auto& intermediate_arg_name : intermediate_arg_names) { // Ignore the user outputs. if (std::find(split_graphs_info_.user_output_names.begin(), split_graphs_info_.user_output_names.end(), @@ -261,7 +290,6 @@ Status ModuleGradientGraphBuilder::Split() { } forward_graph.SetOutputs(forward_output_args); - forward_graph.Resolve(); // Get backward graph. @@ -279,6 +307,7 @@ Status ModuleGradientGraphBuilder::Split() { RemoveNodes(backward_graph, backward_nodes_to_remove); // User inputs to backward graph inputs. + split_graphs_info_.backward_user_input_names.clear(); std::vector backward_input_args; for (const auto& input_name : split_graphs_info_.user_input_names) { // Only takes those in the backward inputs. @@ -289,6 +318,7 @@ Status ModuleGradientGraphBuilder::Split() { } // Add initializer args to backward graph inputs if any node uses them. + split_graphs_info_.backward_intializer_names_as_input.clear(); for (const auto& initializer_name : split_graphs_info_.initializer_names_to_train) { // Some initializers will be inputs for backward graph. if (backward_input_names.find(initializer_name) != backward_input_names.end()) { @@ -322,11 +352,8 @@ Status ModuleGradientGraphBuilder::Split() { } backward_graph.SetOutputs(backward_output_args); - FilterInitializers(backward_graph, backward_input_names); - backward_graph.Resolve(); - return Status::OK(); } diff --git a/orttraining/orttraining/core/framework/module_gradient_graph_builder.h b/orttraining/orttraining/core/framework/module_gradient_graph_builder.h index 33abc38e84..491b5dc19a 100644 --- a/orttraining/orttraining/core/framework/module_gradient_graph_builder.h +++ b/orttraining/orttraining/core/framework/module_gradient_graph_builder.h @@ -44,9 +44,9 @@ struct SplitGraphsInfo { class ModuleGradientGraphBuilder { public: - Status BuildAndSplit(std::istream& model_istream, const ModuleGradientGraphBuilderConfiguration& config); + Status Initialize(std::istream& model_istream, const ModuleGradientGraphBuilderConfiguration& config); + Status BuildAndSplit(const std::vector>& input_shapes); - std::string GetGradientModel() const; std::string GetForwardModel() const; std::string GetBackwardModel() const; SplitGraphsInfo GetSplitGraphsInfo() const { return split_graphs_info_; } @@ -59,7 +59,8 @@ class ModuleGradientGraphBuilder { std::shared_ptr backward_model_; SplitGraphsInfo split_graphs_info_; - const logging::Logger* logger_; + ModuleGradientGraphBuilderConfiguration config_; + const logging::Logger* logger_ = &logging::LoggingManager::DefaultLogger(); // use default logger for now. }; } // namespace training diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc index e8a83ee665..b9099dee1e 100644 --- a/orttraining/orttraining/python/orttraining_pybind_state.cc +++ b/orttraining/orttraining/python/orttraining_pybind_state.cc @@ -380,14 +380,15 @@ void addObjectMethodsForTraining(py::module& m) { .def(py::init([]() { return onnxruntime::make_unique(); })) - .def("build_and_split", [](ModuleGradientGraphBuilder* module_gradient_graph_builder, - const py::bytes& serialized_model, - const ModuleGradientGraphBuilderConfiguration& config) { + .def("initialize", [](ModuleGradientGraphBuilder* module_gradient_graph_builder, + const py::bytes& serialized_model, + const ModuleGradientGraphBuilderConfiguration& config) { std::istringstream buffer(serialized_model); - ORT_THROW_IF_ERROR(module_gradient_graph_builder->BuildAndSplit(buffer, config)); + ORT_THROW_IF_ERROR(module_gradient_graph_builder->Initialize(buffer, config)); }) - .def("get_gradient_model", [](ModuleGradientGraphBuilder* module_gradient_graph_builder) { - return py::bytes(module_gradient_graph_builder->GetGradientModel()); + .def("build_and_split", [](ModuleGradientGraphBuilder* module_gradient_graph_builder, + const std::vector>& input_shapes) { + ORT_THROW_IF_ERROR(module_gradient_graph_builder->BuildAndSplit(input_shapes)); }) .def("get_forward_model", [](ModuleGradientGraphBuilder* module_gradient_graph_builder) { return py::bytes(module_gradient_graph_builder->GetForwardModel()); diff --git a/orttraining/orttraining/python/training/ortmodule.py b/orttraining/orttraining/python/training/ortmodule.py index 14fddf1f12..944d309df2 100644 --- a/orttraining/orttraining/python/training/ortmodule.py +++ b/orttraining/orttraining/python/training/ortmodule.py @@ -56,7 +56,7 @@ def _onnx_value_info_to_buffer_tensor(value_info, device): class ORTModule(torch.nn.Module): - def __init__(self, module): + def __init__(self, module, dynamic_axes=None): assert isinstance(module, torch.nn.Module), "'module' mst be a torch.nn.Module" super(ORTModule, self).__init__() @@ -66,8 +66,12 @@ class ORTModule(torch.nn.Module): # User module is wrapped to use its initializers and save computed gradients self._original_module = module + self._dynamic_axes = dynamic_axes self._onnx_training = None + self._curr_inputs_size = None + self._module_gradient_graph_builder = None + # Forward pass self._onnx_forward = None self._forward_session = None @@ -154,19 +158,28 @@ class ORTModule(torch.nn.Module): if not self._onnx_forward or self._require_export: self._require_export = False - self._onnx_training = ORTModule._get_forward_graph(self._original_module, *inputs, **kwargs) + self._onnx_training = ORTModule._get_forward_graph(self._original_module, self._dynamic_axes, *inputs, **kwargs) grad_builder_config = C.ModuleGradientGraphBuilderConfiguration() # TODO: PyTorch exporter bug: changes the initializer order initializer_names = [p[0] for p in self._original_module.named_parameters()] - onnx_gradient, self._onnx_forward, self._onnx_backward, self._onnx_graphs_info = \ - ORTModule._build_fw_bw_grad_graphs(self._onnx_training, grad_builder_config, - initializer_names, - self._save_onnx) + grad_builder_config.initializer_names_to_train = initializer_names + grad_builder_config.input_names_require_grad = [] + self._module_gradient_graph_builder = C.ModuleGradientGraphBuilder() + self._module_gradient_graph_builder.initialize(self._onnx_training.SerializeToString(), grad_builder_config) if self._save_onnx: onnx.save(self._onnx_training, self._save_onnx_prefix + '_full_training.onnx') - onnx.save(onnx_gradient, self._save_onnx_prefix + '_with_grad.onnx') + + inputs_size = [list(input.size()) for input in inputs if input is not None] + if self._curr_inputs_size is None or self._curr_inputs_size != inputs_size: + self._curr_inputs_size = inputs_size + self._module_gradient_graph_builder.build_and_split(self._curr_inputs_size) + self._onnx_forward = onnx.load_model_from_string(self._module_gradient_graph_builder.get_forward_model()) + self._onnx_backward = onnx.load_model_from_string(self._module_gradient_graph_builder.get_backward_model()) + self._onnx_graphs_info = self._module_gradient_graph_builder.get_split_graphs_info() + + if self._save_onnx: onnx.save(self._onnx_forward, self._save_onnx_prefix + '_forward.onnx') onnx.save(self._onnx_backward, self._save_onnx_prefix + '_backward.onnx') @@ -174,6 +187,7 @@ class ORTModule(torch.nn.Module): self._backward_session = onnxruntime.InferenceSession(self._onnx_backward.SerializeToString()) # IO binding + # TODO: we should try to reuse the output buffers as some of the output tensors are same sizes, expecially the backward graph outputs. self._forward_io_binding = self._forward_session.io_binding() self._forward_output_buffers = {} for output in self._onnx_forward.graph.output: @@ -335,7 +349,7 @@ class ORTModule(torch.nn.Module): @staticmethod - def _get_forward_graph(module, *inputs, **kwargs): + def _get_forward_graph(module, dynamic_axes, *inputs, **kwargs): '''Exports PyTorch `module` to ONNX with training flag, using `*inputs` as input TODO: How to support dynamic axes? Dimensions are determined by samples @@ -363,36 +377,7 @@ class ORTModule(torch.nn.Module): input_names=input_names, opset_version=ONNX_OPSET_VERSION, do_constant_folding=False, - training=torch.onnx.TrainingMode.TRAINING) + training=torch.onnx.TrainingMode.TRAINING, + dynamic_axes=dynamic_axes) return onnx.load_model_from_string(f.getvalue()) - - - @staticmethod - def _build_fw_bw_grad_graphs(forward_graph, config, initializer_names=[], include_gradient_model=False): - '''Adds gradient nodes on top of an existing ONNX graph (with training flag)''' - if not config.initializer_names_to_train: - if not initializer_names: - initializer_names_to_train = [] - for initializer in forward_graph.graph.initializer: - initializer_names_to_train.append(initializer.name) - config.initializer_names_to_train = initializer_names_to_train - else: - config.initializer_names_to_train = initializer_names - - # TODO: Add support to input with grad required - config.input_names_require_grad = [] - # input_names_require_grad = [] - # input_names_require_grad.append('input.1') - # config.input_names_require_grad = input_names_require_grad - - module_gradient_graph_builder = C.ModuleGradientGraphBuilder() - module_gradient_graph_builder.build_and_split(forward_graph.SerializeToString(), config) - forward_model = onnx.load_model_from_string(module_gradient_graph_builder.get_forward_model()) - backward_model = onnx.load_model_from_string(module_gradient_graph_builder.get_backward_model()) - gradient_model = None - if include_gradient_model: - gradient_model = onnx.load_model_from_string(module_gradient_graph_builder.get_gradient_model()) - split_graphs_info = module_gradient_graph_builder.get_split_graphs_info() - - return gradient_model, forward_model, backward_model, split_graphs_info diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py index c2ac0473ff..2be6206d23 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py @@ -49,11 +49,6 @@ def train(model, optimizer, scheduler, train_dataloader, epoch, device, args): if step == args.train_steps: break - # TODO: Dynamic axis is not supported yet - if batch[0].shape[0] != args.batch_size: - logging.warning(f'Dynamic axis is not supported yet {len(batch)}/{args.batch_size}') - continue - # Unpack this training batch from our dataloader. # # As we unpack the batch, we'll also copy each tensor to the GPU using the @@ -159,12 +154,6 @@ def test(model, validation_dataloader, device, args): # Evaluate data for one epoch for batch in validation_dataloader: - - # TODO: Dynamic axis is not supported yet - if batch[0].shape[0] != args.test_batch_size: - logging.warning(f'Dynamic axis is not supported yet {len(batch)}/{args.batch_size}') - continue - # Add batch to GPU batch = tuple(t.to(device) for t in batch) @@ -336,8 +325,8 @@ def main(): help='disables ONNX Runtime training') parser.add_argument('--batch-size', type=int, default=32, metavar='N', help='input batch size for training (default: 32)') - parser.add_argument('--test-batch-size', type=int, default=32, metavar='N', - help='input batch size for testing (default: 32)') + parser.add_argument('--test-batch-size', type=int, default=64, metavar='N', + help='input batch size for testing (default: 64)') parser.add_argument('--view-graphs', action='store_true', default=False, help='views forward and backward graphs') parser.add_argument('--no-cuda', action='store_true', default=False, @@ -391,7 +380,11 @@ def main(): ) if not args.pytorch_only: - model = ORTModule(model) + dynamic_axes = {'input_ids': {0: 'batch_size', 1: 'seq_len'}, + 'attention_mask': {0: 'batch_size', 1: 'seq_len'}, + 'labels': {0: 'batch_size'}, + '210': {0: 'batch'}} + model = ORTModule(model, dynamic_axes) # TODO: change it to False to stop saving ONNX models model._save_onnx = True