Remove initializers from forward ONNX graph

This commit is contained in:
Thiago Crepaldi 2020-11-19 17:33:55 -08:00
parent 07f5ae95e5
commit 41b88ce91d
2 changed files with 9 additions and 28 deletions

View file

@ -234,6 +234,7 @@ Status ModuleGradientGraphBuilder::Split() {
// Add initializers to forward graph inputs.
for (const auto& initializer_name : split_graphs_info_.initializer_names_to_train) {
forward_input_args.emplace_back(forward_graph.GetNodeArg(initializer_name));
forward_graph.RemoveInitializedTensor(initializer_name);
}
forward_graph.SetInputs(forward_input_args);
@ -256,12 +257,7 @@ Status ModuleGradientGraphBuilder::Split() {
forward_graph.SetOutputs(forward_output_args);
// Resolve the forward graph; the trainable initializers were removed above and are now plain graph inputs.
Graph::ResolveOptions options;
std::unordered_set<std::string> initializer_names_to_train_set(split_graphs_info_.initializer_names_to_train.begin(),
split_graphs_info_.initializer_names_to_train.end());
options.initializer_names_to_preserve = &initializer_names_to_train_set;
forward_graph.Resolve(options);
forward_graph.Resolve();
// Get backward graph.
Graph& backward_graph = backward_model_->MainGraph();

View file

@ -259,21 +259,17 @@ class ORTModule(torch.nn.Module):
TODO: How does IO binding of model inputs and outputs affect initializer copies?
ONNX Runtime forward requires an ordered list of:
* User input: computed from forward InferenceSession
* User input: computed from ONNX forward graph, excluding initializers as input
* Initializers: computed from original PyTorch model parameters
This code assumes the exported model's inputs and initializers
are the same as those of the original PyTorch model
'''
# List containing both user inputs and initializers, in this order
result = []
# Inputs
for idx, input_data in enumerate(self._forward_session.get_inputs()):
result.append(inputs[idx])
# User inputs
result = list(inputs[:len(self._onnx_graphs_info.user_input_names)])
# Initializers
for idx, param in enumerate(self._original_module.named_parameters()):
for param in self._original_module.named_parameters():
result.append(param[1])
return result
@ -284,20 +280,9 @@ class ORTModule(torch.nn.Module):
TODO: Input gradient is being ignored for MVP
'''
# Dictionary containing both inputs and initializers
result = {}
# Inputs
result_len = 0
for idx, input_data in enumerate(self._forward_session.get_inputs()):
result_len += 1
result.update({input_data.name: inputs[idx]})
# Initializers
for param in self._original_module.named_parameters():
result.update({param[0]: inputs[result_len]})
result_len += 1
return result
forward_input_names = [*self._onnx_graphs_info.user_input_names,
*self._onnx_graphs_info.initializer_names_to_train]
return dict(zip(forward_input_names, inputs))
def _convert_backward_input_list_to_dict(self, *inputs):
'''Convert backward `*inputs` list to dict