diff --git a/orttraining/orttraining/core/framework/module_gradient_graph_builder.cc b/orttraining/orttraining/core/framework/module_gradient_graph_builder.cc index 55df8a0a25..1e76789901 100644 --- a/orttraining/orttraining/core/framework/module_gradient_graph_builder.cc +++ b/orttraining/orttraining/core/framework/module_gradient_graph_builder.cc @@ -234,6 +234,7 @@ Status ModuleGradientGraphBuilder::Split() { // Add initializers to forward graph inputs. for (const auto& initializer_name : split_graphs_info_.initializer_names_to_train) { forward_input_args.emplace_back(forward_graph.GetNodeArg(initializer_name)); + forward_graph.RemoveInitializedTensor(initializer_name); } forward_graph.SetInputs(forward_input_args); @@ -256,12 +257,7 @@ Status ModuleGradientGraphBuilder::Split() { forward_graph.SetOutputs(forward_output_args); - // Resolve the forward graph, keep the trainable initializers for now. - Graph::ResolveOptions options; - std::unordered_set initializer_names_to_train_set(split_graphs_info_.initializer_names_to_train.begin(), - split_graphs_info_.initializer_names_to_train.end()); - options.initializer_names_to_preserve = &initializer_names_to_train_set; - forward_graph.Resolve(options); + forward_graph.Resolve(); // Get backward graph. Graph& backward_graph = backward_model_->MainGraph(); diff --git a/orttraining/orttraining/python/training/ortmodule.py b/orttraining/orttraining/python/training/ortmodule.py index fcd7e15aab..59cc53bc15 100644 --- a/orttraining/orttraining/python/training/ortmodule.py +++ b/orttraining/orttraining/python/training/ortmodule.py @@ -259,21 +259,17 @@ class ORTModule(torch.nn.Module): TODO: How IO binding model inputs and outputs affects initializer copies? ONNX Runtime forward requires an order list of: - * User input: computed from forward InferenceSession + * User input: computed from ONNX forward graph, excluding initializers as input * Initializers: computed from original PyTorch model parameters This codes assumes the exported model's inputs and initializers are the same as the original PyTorch model ''' - # List containing both user inputs and initializers, in this order - result = [] - - # Inputs - for idx, input_data in enumerate(self._forward_session.get_inputs()): - result.append(inputs[idx]) + # User inputs + result = list(inputs[:len(self._onnx_graphs_info.user_input_names)]) # Initializers - for idx, param in enumerate(self._original_module.named_parameters()): + for param in self._original_module.named_parameters(): result.append(param[1]) return result @@ -284,20 +280,9 @@ class ORTModule(torch.nn.Module): TODO: Input gradient is being ignored for MVP ''' # Dictionary containing both inputs and initializers - result = {} - - # Inputs - result_len = 0 - for idx, input_data in enumerate(self._forward_session.get_inputs()): - result_len += 1 - result.update({input_data.name: inputs[idx]}) - - # Initializers - for param in self._original_module.named_parameters(): - result.update({param[0]: inputs[result_len]}) - result_len += 1 - - return result + forward_input_names = [*self._onnx_graphs_info.user_input_names, + *self._onnx_graphs_info.initializer_names_to_train] + return dict(zip(forward_input_names, inputs)) def _convert_backward_input_list_to_dict(self, *inputs): '''Convert backward `*inputs` list to dict