From c5aeaa94196bbcbdead24f37dae930d744fc2186 Mon Sep 17 00:00:00 2001 From: baijumeswani Date: Tue, 11 May 2021 12:26:56 -0700 Subject: [PATCH] Support for unused model initializers (#7631) * Support for unused model initializers * Change graph_info.initializer* to sets --- .../core/framework/ortmodule_graph_builder.cc | 16 +++--- .../core/framework/ortmodule_graph_builder.h | 6 +-- .../ortmodule/_graph_execution_manager.py | 11 ++-- .../training/ortmodule/_inference_manager.py | 3 +- .../python/training/ortmodule/_io.py | 7 +-- .../training/ortmodule/_training_manager.py | 12 ++--- .../orttraining/test/python/_test_helpers.py | 2 +- .../python/orttraining_test_ortmodule_api.py | 50 +++++++++++++++++++ 8 files changed, 83 insertions(+), 24 deletions(-) diff --git a/orttraining/orttraining/core/framework/ortmodule_graph_builder.cc b/orttraining/orttraining/core/framework/ortmodule_graph_builder.cc index f80d3f4595..96c6bfa92c 100644 --- a/orttraining/orttraining/core/framework/ortmodule_graph_builder.cc +++ b/orttraining/orttraining/core/framework/ortmodule_graph_builder.cc @@ -39,10 +39,12 @@ Status OrtModuleGraphBuilder::Initialize(std::istream& model_istream, graph_info_.user_output_names.emplace_back(node_arg->Name()); } - graph_info_.initializer_names_to_train.assign(config.initializer_names_to_train.begin(), - config.initializer_names_to_train.end()); - graph_info_.initializer_names.assign(config.initializer_names.begin(), - config.initializer_names.end()); + graph_info_.initializer_names_to_train = std::unordered_set( + config_.initializer_names_to_train.begin(), + config_.initializer_names_to_train.end()); + graph_info_.initializer_names = std::unordered_set( + config_.initializer_names.begin(), + config_.initializer_names.end()); std::vector input_args; for (const auto& input_name : graph_info_.user_input_names) { @@ -50,7 +52,7 @@ Status OrtModuleGraphBuilder::Initialize(std::istream& model_istream, } // Remove all the initializers from the graph and move them to graph inputs. - for (const auto& initializer_name : graph_info_.initializer_names) { + for (const auto& initializer_name : config_.initializer_names) { const NodeArg* node_arg = graph.GetNodeArg(initializer_name); ORT_ENFORCE(node_arg != nullptr); input_args.emplace_back(node_arg); @@ -314,11 +316,11 @@ void OrtModuleGraphBuilder::ReorderOutputs() { // Add initializer gradients to graph outputs. graph_info_.initializer_grad_names_to_train.clear(); - for (const auto& initializer_name : graph_info_.initializer_names_to_train) { + for (const auto& initializer_name : config_.initializer_names_to_train) { std::string initializer_gradient_name = GradientBuilderBase::GradientName(initializer_name); ORT_ENFORCE(gradient_output_arg_map.find(initializer_gradient_name) != gradient_output_arg_map.end(), "Trainable initializer grad is not found on gradient graph."); - graph_info_.initializer_grad_names_to_train.emplace_back(initializer_gradient_name); + graph_info_.initializer_grad_names_to_train.emplace(initializer_gradient_name); new_output_args.emplace_back(gradient_output_arg_map[initializer_gradient_name]); } diff --git a/orttraining/orttraining/core/framework/ortmodule_graph_builder.h b/orttraining/orttraining/core/framework/ortmodule_graph_builder.h index 9925be3a20..139ef468f9 100644 --- a/orttraining/orttraining/core/framework/ortmodule_graph_builder.h +++ b/orttraining/orttraining/core/framework/ortmodule_graph_builder.h @@ -44,11 +44,11 @@ struct GraphInfo { // Map from user input names to corresponding user input grad names for those user inputs that require grad. std::unordered_map user_input_grad_names{}; // All initializers (trainable as well as non trainable). - std::vector initializer_names{}; + std::unordered_set initializer_names{}; // Trainable initializers. - std::vector initializer_names_to_train{}; + std::unordered_set initializer_names_to_train{}; // Trainable initializer grad names, ordered according to initializer_names_to_train. - std::vector initializer_grad_names_to_train{}; + std::unordered_set initializer_grad_names_to_train{}; // The user outputs. std::vector user_output_names{}; // Indices of output grads that are non-differentiable. diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index 79e6009e3e..f3c995f994 100644 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -282,10 +282,15 @@ class GraphExecutionManager(ABC): def _initialize_graph_builder(self, training): """Creates a new OrtModuleGraphBuilder, initializes it and saves it to self._graph_builder""" + # All initializer names along with user inputs are a part of the onnx graph inputs + # since the onnx model was exported with the flag keep_initializers_as_inputs=True + onnx_initializer_names = {p.name for p in self._onnx_model.graph.input} + # TODO: PyTorch exporter bug: changes the initializer order in ONNX model - initializer_names = [name for name, _ in self._flattened_module.named_parameters()] - initializer_names_to_train = [name for name, - param in self._flattened_module.named_parameters() if param.requires_grad] + initializer_names = [name for name, _ in self._flattened_module.named_parameters() + if name in onnx_initializer_names] + initializer_names_to_train = [name for name, param in self._flattened_module.named_parameters() + if param.requires_grad and name in onnx_initializer_names] # Build and optimize the full graph grad_builder_config = C.OrtModuleGraphBuilderConfiguration() diff --git a/orttraining/orttraining/python/training/ortmodule/_inference_manager.py b/orttraining/orttraining/python/training/ortmodule/_inference_manager.py index 57f0ca579b..5359e05b6a 100644 --- a/orttraining/orttraining/python/training/ortmodule/_inference_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_inference_manager.py @@ -89,7 +89,8 @@ class InferenceManager(GraphExecutionManager): self._optimized_onnx_model, self._device, *_io._combine_input_buffers_initializers( - self._flattened_module.named_parameters(), + [param for name, param in self._flattened_module.named_parameters() + if name in self._graph_info.initializer_names], self._graph_info.user_input_names, self._input_info, self._flattened_module.named_buffers(), diff --git a/orttraining/orttraining/python/training/ortmodule/_io.py b/orttraining/orttraining/python/training/ortmodule/_io.py index b673130ebb..6be66804fe 100644 --- a/orttraining/orttraining/python/training/ortmodule/_io.py +++ b/orttraining/orttraining/python/training/ortmodule/_io.py @@ -73,7 +73,7 @@ class _InputInfo(object): if name in self.keyword_names} return args, kwargs -def _combine_input_buffers_initializers(param_names, onnx_input_names, input_info, buffer_names, inputs, kwargs, device): +def _combine_input_buffers_initializers(params, onnx_input_names, input_info, buffer_names, inputs, kwargs, device): '''Creates forward `*inputs` list from user input and PyTorch initializers ONNX Runtime forward requires an ordered list of: @@ -120,8 +120,9 @@ def _combine_input_buffers_initializers(param_names, onnx_input_names, input_inf else: raise RuntimeError(f'Input is present in ONNX graph but not provided: {name}.') - # Initializers - result.extend([param[1] for param in param_names]) + # params is a list of all initializers known to the onnx graph + result.extend(params) + return result diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py index 00f19e1ea5..ed7b30f98e 100644 --- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py @@ -173,22 +173,22 @@ class TrainingManager(GraphExecutionManager): # Append gradients of initializer to results # Go over each initializer, check if it required grad and append to results accordingly - initializer_names_to_train_set = set(self._graph_info.initializer_names_to_train) initializer_index = num_user_input_grads - for initializer_name in self._graph_info.initializer_names: - if initializer_name in initializer_names_to_train_set: + for initializer_name, _ in self._flattened_module.named_parameters(): + if initializer_name in self._graph_info.initializer_names_to_train: results.append(_utils._ortvalue_to_torch_tensor(backward_outputs[initializer_index])) initializer_index += 1 else: results.append(None) - + return tuple(results) return _io.unflatten_user_output(self._module_output_schema, self._graph_info.user_output_names, _ORTModuleFunction.apply( *_io._combine_input_buffers_initializers( - self._flattened_module.named_parameters(), + [param for name, param in self._flattened_module.named_parameters() + if name in self._graph_info.initializer_names], self._graph_info.user_input_names, self._input_info, self._flattened_module.named_buffers(), @@ -238,7 +238,7 @@ class TrainingManager(GraphExecutionManager): initializer_names_to_train_set_user_model = {name for name, param in self._flattened_module.named_parameters() if param.requires_grad} - initializer_names_to_train_set_onnx_graph = set(self._graph_info.initializer_names_to_train) \ + initializer_names_to_train_set_onnx_graph = self._graph_info.initializer_names_to_train \ if self._graph_info else None # If inputs requiring gradient change from forward to the next, the module_gradient_graph_builder diff --git a/orttraining/orttraining/test/python/_test_helpers.py b/orttraining/orttraining/test/python/_test_helpers.py index ead4f42a4a..95cbbf773e 100644 --- a/orttraining/orttraining/test/python/_test_helpers.py +++ b/orttraining/orttraining/test/python/_test_helpers.py @@ -160,7 +160,7 @@ def assert_gradients_match_and_reset_gradient(ort_model, pt_model, none_pt_param assert pt_name in ort_name if pt_name in none_pt_params: assert pt_param.grad is None - assert not torch.is_nonzero(torch.count_nonzero(ort_param.grad)) + assert ort_param.grad is None or not torch.is_nonzero(torch.count_nonzero(ort_param.grad)) else: assert_values_are_close(ort_param.grad, pt_param.grad, rtol=rtol, atol=atol) diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index df26250ed0..f0203a57c8 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -2362,3 +2362,53 @@ def test_model_with_registered_buffer_and_dropped_parameters(): # Ensure that no exceptions are raised out = model(bool_argument, x) + +def test_unused_parameters(): + class UnusedParameterNet(torch.nn.Module): + def __init__(self, input_size, hidden_size, num_classes): + super(UnusedParameterNet, self).__init__() + + self.fc1 = torch.nn.Linear(input_size, hidden_size) + self.relu = torch.nn.ReLU() + # fc2 is an unused initializer which will be dropped after export + self.fc2 = torch.nn.Linear(hidden_size, num_classes) + self.register_buffer("buffer", torch.ones(hidden_size)) + + def forward(self, input1): + out = self.fc1(input1) + out = self.relu(out) + out = out + self.buffer + return out + + device = 'cuda' + + N, D_in, H, D_out = 64, 784, 500, 10 + model = UnusedParameterNet(D_in, H, D_out).to(device) + ort_model = ORTModule(copy.deepcopy(model)) + + # Make sure model runs without any exception + for _ in range(5): + x = torch.randn(N, D_in, device=device) + y = copy.deepcopy(x) + + out_pt = model(x) + out_ort = ort_model(y) + loss_pt = out_pt.sum() + loss_pt.backward() + loss_ort = out_ort.sum() + loss_ort.backward() + _test_helpers.assert_values_are_close(out_ort, out_pt) + _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, model, + none_pt_params=['fc2.weight', 'fc2.bias']) + + # Also try in eval mode + model.eval() + ort_model.eval() + + x = torch.randn(N, D_in, device=device) + y = copy.deepcopy(x) + + # Make sure model runs without any exception + out_pt = model(x) + out_ort = ort_model(y) + _test_helpers.assert_values_are_close(out_ort, out_pt)