mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-09 00:30:53 +00:00
Remove initializers from forward ONNX graph
This commit is contained in:
parent
07f5ae95e5
commit
41b88ce91d
2 changed files with 9 additions and 28 deletions
|
|
@ -234,6 +234,7 @@ Status ModuleGradientGraphBuilder::Split() {
|
|||
// Add initializers to forward graph inputs.
|
||||
for (const auto& initializer_name : split_graphs_info_.initializer_names_to_train) {
|
||||
forward_input_args.emplace_back(forward_graph.GetNodeArg(initializer_name));
|
||||
forward_graph.RemoveInitializedTensor(initializer_name);
|
||||
}
|
||||
|
||||
forward_graph.SetInputs(forward_input_args);
|
||||
|
|
@ -256,12 +257,7 @@ Status ModuleGradientGraphBuilder::Split() {
|
|||
|
||||
forward_graph.SetOutputs(forward_output_args);
|
||||
|
||||
// Resolve the forward graph, keep the trainable initializers for now.
|
||||
Graph::ResolveOptions options;
|
||||
std::unordered_set<std::string> initializer_names_to_train_set(split_graphs_info_.initializer_names_to_train.begin(),
|
||||
split_graphs_info_.initializer_names_to_train.end());
|
||||
options.initializer_names_to_preserve = &initializer_names_to_train_set;
|
||||
forward_graph.Resolve(options);
|
||||
forward_graph.Resolve();
|
||||
|
||||
// Get backward graph.
|
||||
Graph& backward_graph = backward_model_->MainGraph();
|
||||
|
|
|
|||
|
|
@ -259,21 +259,17 @@ class ORTModule(torch.nn.Module):
|
|||
TODO: How does IO binding of model inputs and outputs affect initializer copies?
|
||||
|
||||
ONNX Runtime forward requires an ordered list of:
|
||||
* User input: computed from forward InferenceSession
|
||||
* User input: computed from ONNX forward graph, excluding initializers as input
|
||||
* Initializers: computed from original PyTorch model parameters
|
||||
|
||||
This code assumes the exported model's inputs and initializers
|
||||
are the same as the original PyTorch model
|
||||
'''
|
||||
# List containing both user inputs and initializers, in this order
|
||||
result = []
|
||||
|
||||
# Inputs
|
||||
for idx, input_data in enumerate(self._forward_session.get_inputs()):
|
||||
result.append(inputs[idx])
|
||||
# User inputs
|
||||
result = list(inputs[:len(self._onnx_graphs_info.user_input_names)])
|
||||
|
||||
# Initializers
|
||||
for idx, param in enumerate(self._original_module.named_parameters()):
|
||||
for param in self._original_module.named_parameters():
|
||||
result.append(param[1])
|
||||
|
||||
return result
|
||||
|
|
@ -284,20 +280,9 @@ class ORTModule(torch.nn.Module):
|
|||
TODO: Input gradient is being ignored for MVP
|
||||
'''
|
||||
# Dictionary containing both inputs and initializers
|
||||
result = {}
|
||||
|
||||
# Inputs
|
||||
result_len = 0
|
||||
for idx, input_data in enumerate(self._forward_session.get_inputs()):
|
||||
result_len += 1
|
||||
result.update({input_data.name: inputs[idx]})
|
||||
|
||||
# Initializers
|
||||
for param in self._original_module.named_parameters():
|
||||
result.update({param[0]: inputs[result_len]})
|
||||
result_len += 1
|
||||
|
||||
return result
|
||||
forward_input_names = [*self._onnx_graphs_info.user_input_names,
|
||||
*self._onnx_graphs_info.initializer_names_to_train]
|
||||
return dict(zip(forward_input_names, inputs))
|
||||
|
||||
def _convert_backward_input_list_to_dict(self, *inputs):
|
||||
'''Convert backward `*inputs` list to dict
|
||||
|
|
|
|||
Loading…
Reference in a new issue