# onnxruntime/orttraining/orttraining/python/training/ortmodule.py
# 2020-12-15 09:03:08 -08:00
#
# 653 lines
# 31 KiB
# Python

import copy
import io
import logging
import onnx
import onnxruntime
import os
import torch
import warnings
from onnxruntime.capi import _pybind_state as C
from . import _utils
# ONNX opset version passed to torch.onnx.export when the wrapped module is exported
# (see ORTModule._get_forward_graph).
ONNX_OPSET_VERSION = 12
class ORTModule(torch.nn.Module):
    '''Runs the forward and backward passes of a wrapped ``torch.nn.Module`` on ONNX Runtime.

    On the first call to :meth:`forward` the wrapped module is exported to ONNX, a gradient
    graph is built on top of the export, and that graph is split into a forward graph and a
    backward graph. Each graph gets its own ``onnxruntime.InferenceSession`` (CPU only for now).
    A custom ``torch.autograd.Function`` ties the two sessions into PyTorch's autograd.
    '''

    # TODO: thiagofc: BERT: Remove this once the graph splitter can handle unspecified optional
    # input (without type). Backward-graph inputs with these names are declared as BOOL instead
    # of their inferred element type. These are presumably tensor names from one specific BERT
    # export; verify before relying on them for any other model.
    _BERT_BOOL_INPUT_NAMES = frozenset({
        '1835', '1813', '1781', '1760', '1683', '1651', '1630', '1553', '1521', '1500',
        '1423', '1391', '1370', '1293', '1261', '1240', '1163', '1131', '1110', '1033',
        '1001', '980', '871', '267', '330', '351', '383', '460', '481', '513', '590', '611',
        '643', '720', '741', '773', '850', '903',
    })

    def __init__(self, module):
        assert isinstance(module, torch.nn.Module), "'module' must be a torch.nn.Module"
        super(ORTModule, self).__init__()

        # User module is wrapped to use its initializers and save computed gradients
        self._original_module = module
        # Number of grad_output tensors torch passes to backward; -1 until the first backward call
        self._original_module_grad_output_len = -1
        # Ordered names of the parameter-gradient outputs requested from the backward session
        self._original_module_forward_input_grads = []
        # Full training graph (as exported by torch.onnx.export) and its I/O descriptors
        self._onnx_training = None
        self._onnx_training_inputs_desc = []
        self._onnx_training_outputs_desc = []
        # Training graph with gradient nodes added by the C++ gradient graph builder
        self._onnx_gradient = None
        self._grad_builder_config = C.ModuleGradientGraphBuilderConfiguration()

        # Forward pass
        self._onnx_forward = None
        self._forward_session = None
        self._onnx_forward_initializers_desc = []
        self._onnx_forward_inputs_desc = []
        self._onnx_forward_outputs_desc = []
        self._onnx_forward_intermediate_outputs_desc = []

        # Backward pass
        self._onnx_backward = None
        self._backward_session = None
        self._onnx_backward_initializers_desc = []
        self._onnx_backward_inputs_desc = []
        self._onnx_backward_gradient_inputs_desc = []
        self._onnx_backward_outputs_desc = []

        # Log level (not read anywhere in this class yet)
        self._loglevel = logging.WARNING

        # TODO: debug flags. When _save_onnx is True, the first forward call writes the graphs to
        # '<_save_onnx_prefix>_full_training.onnx', '_with_grad.onnx', '_forward.onnx', '_backward.onnx'.
        self._save_onnx = False
        self._save_onnx_prefix = ''

    def forward(self, *inputs, **kwargs):
        '''Forward pass starts here and continues at `_ORTModuleFunction.forward`.

        The ONNX model is exported the first time this method is executed. Next, the full
        training graph is split into forward and backward graphs, which are used to instantiate
        ONNX Runtime InferenceSessions (TODO: #ImproveGraphSplitting). In addition, several
        descriptor lists are generated to help identify model input, output, initializer,
        intermediate and gradient tensors.
        '''
        if self._onnx_forward is None:
            self._initialize_graphs_and_sessions(*inputs, **kwargs)
            self._initialize_io_descriptors()

        # Use a custom torch.autograd.Function to associate self._onnx_backward as the
        # gradient implementation for self._onnx_forward. The static methods below reach this
        # ORTModule instance through the enclosing `self` (a closure), not through `ctx`.
        class _ORTModuleFunction(torch.autograd.Function):
            @staticmethod
            def forward(ctx, *inputs, **kwargs):
                '''Performs the forward pass on user input and PyTorch initializers.

                TODO: **kwargs are not supported.
                Model outputs are returned to the user. The following tensors are stashed
                (in this order) for the backward pass:
                    * (Partial) user input
                    * (Partial) initializers, in `_onnx_backward_initializers_desc` order
                    * Intermediate tensors
                TODO: #ImproveGraphSplitting
                    String matching is used to separate user input from initializers.
                '''
                # Convert input to dict of torch tensors
                data_dict = self._convert_forward_input_list_to_dict(*inputs)
                # Convert dict of torch tensors to dict of numpy arrays (ORT BE requirement)
                data_dict_numpy = self._convert_dict_torch_to_numpy(data_dict)

                # Feed forward
                outputs, intermediates = self._run_forward_graph(data_dict_numpy)
                outputs = tuple(torch.from_numpy(item) for item in outputs)

                # Save input, initializers and intermediate tensors to be used during backward
                initializer_names = [item['name'] for item in self._onnx_backward_initializers_desc]
                initializer_name_set = set(initializer_names)
                input_names = {item['name'] for item in self._onnx_backward_inputs_desc
                               if item['name'] not in initializer_name_set}
                ctx_input = tuple(v for k, v in data_dict.items() if k in input_names)
                # The backward pass re-labels these tensors by walking _onnx_backward_initializers_desc,
                # so they must be stashed in exactly that order (not in named_parameters order).
                ctx_initializer = tuple(data_dict[name] for name in initializer_names if name in data_dict)
                intermediates = tuple(torch.from_numpy(item) for item in intermediates)
                ctx.save_for_backward(*ctx_input, *ctx_initializer, *intermediates)

                # TODO: Support original module output (currently dict is not supported)
                if len(outputs) == 1:
                    return outputs[0]
                return outputs

            @staticmethod
            def backward(ctx, *grad_output):
                '''Performs the backward pass from grad wrt outputs and the tensors stashed in forward.

                Internal state is composed of the tensors stashed (in a particular order) during
                forward: (partial) user input, (partial) initializers and intermediate tensors.
                TODO: #ImproveGraphSplitting
                    The length of `*grad_output` is needed to detect intermediate tensors.
                TODO: Gradients wrt user input are not computed; None is returned for each one.
                '''
                saved_tensors = ctx.saved_tensors

                # Used to create backward input
                if self._original_module_grad_output_len == -1:
                    self._original_module_grad_output_len = len(grad_output)

                grad_weights = self._run_backward_graph(*saved_tensors, *grad_output)
                # None tells autograd "no gradient" for an input; a placeholder tensor of the wrong
                # shape/dtype would be rejected for any user input that requires grad.
                result = [None] * len(self._onnx_training_inputs_desc)
                result.extend(torch.from_numpy(grad) for grad in grad_weights)
                return tuple(result)

        return _ORTModuleFunction.apply(*self._convert_forward_input_to_list(*inputs, **kwargs))

    def _initialize_graphs_and_sessions(self, *inputs, **kwargs):
        '''Exports the wrapped module, builds and splits the gradient graph, and creates both sessions.'''
        self._onnx_training = ORTModule._get_forward_graph(self._original_module, *inputs, **kwargs)
        self._onnx_gradient = ORTModule._build_gradient_graph(self._onnx_training, self._grad_builder_config)
        self._onnx_forward, self._onnx_backward = ORTModule._split_forward_and_backward(
            self._onnx_gradient, self._grad_builder_config.weight_names_to_train)

        if self._save_onnx:
            onnx.save(self._onnx_training, self._save_onnx_prefix + '_full_training.onnx')
            onnx.save(self._onnx_gradient, self._save_onnx_prefix + '_with_grad.onnx')
            onnx.save(self._onnx_forward, self._save_onnx_prefix + '_forward.onnx')
            onnx.save(self._onnx_backward, self._save_onnx_prefix + '_backward.onnx')

        # TODO: hard-coding to CPU only
        self._forward_session = onnxruntime.InferenceSession(
            self._onnx_forward.SerializeToString(), providers=['CPUExecutionProvider'])
        self._backward_session = onnxruntime.InferenceSession(
            self._onnx_backward.SerializeToString(), providers=['CPUExecutionProvider'])

    def _initialize_io_descriptors(self):
        '''Computes every I/O descriptor list once, in dependency order.

        The forward initializers must come first: _get_input_from_graph and _get_output_from_graph
        filter against them to tell initializers apart from other tensors.
        '''
        # Forward I/O description
        self._onnx_forward_initializers_desc = self._get_initializer_from_graph(self._onnx_forward)
        logging.debug('Forward initializers:\n\t %s', self._onnx_forward_initializers_desc)
        self._onnx_training_inputs_desc = self._get_input_from_graph(self._onnx_training)
        logging.debug('Training inputs:\n\t %s', self._onnx_training_inputs_desc)
        self._onnx_training_outputs_desc = self._get_output_from_graph(self._onnx_training)
        logging.debug('Training outputs:\n\t %s', self._onnx_training_outputs_desc)
        self._onnx_forward_inputs_desc = self._get_input_from_graph(self._onnx_forward)
        logging.debug('Forward inputs:\n\t %s', self._onnx_forward_inputs_desc)
        self._onnx_forward_outputs_desc = self._get_output_from_graph(self._onnx_forward)
        logging.debug('Forward outputs:\n\t %s', self._onnx_forward_outputs_desc)
        self._onnx_forward_intermediate_outputs_desc = self._get_intermediate_from_forward_graph(self._onnx_forward)
        logging.debug('Forward intermediate outputs:\n\t %s', self._onnx_forward_intermediate_outputs_desc)

        # Backward I/O description
        self._onnx_backward_initializers_desc = self._get_input_from_graph(self._onnx_backward, True)
        logging.debug('Backward initializers: %s', self._onnx_backward_initializers_desc)
        self._onnx_backward_inputs_desc = self._get_input_from_graph(
            self._onnx_backward, False, self._onnx_backward_initializers_desc)
        logging.debug('Backward inputs: %s', self._onnx_backward_inputs_desc)
        self._onnx_backward_gradient_inputs_desc = self._get_gradient_input_from_graph(
            self._onnx_backward, self._onnx_forward_inputs_desc, self._onnx_forward_initializers_desc,
            self._onnx_forward_intermediate_outputs_desc)
        logging.debug('Backward gradient inputs: %s', self._onnx_backward_gradient_inputs_desc)
        self._onnx_backward_outputs_desc = self._get_output_from_graph(self._onnx_backward)
        logging.debug('Backward outputs: %s', self._onnx_backward_outputs_desc)

    def _convert_forward_input_to_list(self, *inputs, **kwargs):
        '''Creates the forward `*inputs` list from user input and PyTorch initializers.

        TODO: **kwargs is not supported.
        ONNX Runtime forward requires an ordered list of:
            * User input: one per input reported by the forward InferenceSession
            * Initializers: the original PyTorch model parameters, in named_parameters() order
        Raises IndexError if fewer user inputs are given than the forward session expects.
        '''
        num_user_inputs = len(self._forward_session.get_inputs())
        result = [inputs[idx] for idx in range(num_user_inputs)]
        result.extend(param for _, param in self._original_module.named_parameters())
        return result

    def _convert_dict_torch_to_numpy(self, tensor_dict):
        '''Converts the PyTorch tensors in `tensor_dict` to numpy arrays (same keys, same order).

        This is an ONNX Runtime requirement. Tensors are detached and moved to CPU first.
        TODO: #UseIOBinding
        '''
        return {name: tensor.detach().cpu().numpy() for name, tensor in tensor_dict.items()}

    def _convert_forward_input_list_to_dict(self, *inputs):
        '''Converts the forward `*inputs` list to a dict keyed by ONNX input / parameter name.

        TODO: #ImproveGraphSplitting
        As a side effect, the ordered list of initializer gradient names used by the backward pass
        (self._original_module_forward_input_grads) is filled on the first call.
        TODO: Input gradient is being ignored for MVP
        '''
        # Dictionary containing both inputs and initializers
        result = {}

        # Inputs
        num_user_inputs = 0
        for idx, input_info in enumerate(self._forward_session.get_inputs()):
            result[input_info.name] = inputs[idx]
            num_user_inputs += 1

        # Initializers
        for position, (name, _) in enumerate(self._original_module.named_parameters(), start=num_user_inputs):
            result[name] = inputs[position]
            # TODO: Create an ordered list of input grads for scenarios where gradients of the
            # user input are required (not covered on MVP).
            # Ordered list of initializer grads to use during backward.
            if len(self._original_module_forward_input_grads) < len(self._onnx_backward_outputs_desc):
                self._original_module_forward_input_grads.append(name + '_grad')
        return result

    def _convert_backward_input_list_to_dict(self, *inputs):
        '''Converts the backward `*inputs` list to a dict keyed by backward-graph input name.

        ONNX Runtime backend requires a dict as input, composed of (in positional order):
            * User input (one per forward-session input)
            * (Partial) initializers, in _onnx_backward_initializers_desc order
            * Intermediate tensors TODO: #ImproveGraphSplitting
                interm_count = len(all_inputs) - len(user_input) - len(initializer) - len(grad_output)
            * Gradient wrt outputs TODO: #ImproveGraphSplitting
                all remaining inputs, named by _onnx_backward_gradient_inputs_desc
        '''
        result = {}
        position = 0

        # Inputs
        for input_info in self._forward_session.get_inputs():
            result[input_info.name] = inputs[position]
            position += 1

        # Initializers
        for initializer in self._onnx_backward_initializers_desc:
            result[initializer['name']] = inputs[position]
            position += 1

        # Intermediate
        intermediate_len = len(inputs) - position - self._original_module_grad_output_len
        for idx in range(intermediate_len):
            result[self._onnx_forward_intermediate_outputs_desc[idx]['name']] = inputs[position]
            position += 1

        # Grad outputs
        for idx in range(len(inputs) - position):
            result[self._onnx_backward_gradient_inputs_desc[idx]['name']] = inputs[position]
            position += 1

        return result

    def _run_forward_graph(self, inputs):
        '''Executes the forward pass on ONNX Runtime.

        Output order is specified explicitly so model outputs (first) can be told apart from the
        intermediate tensors (remainder). Returns (outputs, intermediates) as lists of numpy arrays.
        '''
        output_names = [out['name'] for out in self._onnx_forward_outputs_desc]
        forward_output = self._forward_session.run(output_names, inputs)
        num_outputs = len(self._onnx_training_outputs_desc)
        return forward_output[:num_outputs], forward_output[num_outputs:]

    def _run_backward_graph(self, *inputs, **kwargs):
        '''Executes the backward pass on ONNX Runtime.

        `*inputs` is converted to a dict of detached numpy arrays before being fed to the backward
        InferenceSession. Returns the initializer gradients in
        self._original_module_forward_input_grads order.
        TODO: **kwargs are not supported.
        '''
        # Convert input to dict of torch tensors
        data = self._convert_backward_input_list_to_dict(*inputs)
        # Convert dict of torch tensors to dict of numpy arrays (ORT BE requirement)
        data = self._convert_dict_torch_to_numpy(data)
        return self._backward_session.run(self._original_module_forward_input_grads, data)

    @staticmethod
    def _get_forward_graph(module, *inputs, **kwargs):
        '''Exports PyTorch `module` to ONNX (training mode) using `*inputs` as sample input.

        TODO: How to support dynamic axes? Dimensions are determined by samples.
        TODO: How to ingest **kwargs in proper order during export? They are ignored for now.
        '''
        # Deepcopy inputs, since input values may change after model run.
        sample_inputs_copy = copy.deepcopy(inputs)

        # TODO: Support contrib OPs? The user model gives no hint. Presumably this would use
        # onnxruntime.training.register_custom_ops_pytorch_exporter.register_custom_op().

        # Export torch.nn.Module to ONNX in memory
        f = io.BytesIO()
        torch.onnx.export(module,
                          tuple(sample_inputs_copy),
                          f,
                          opset_version=ONNX_OPSET_VERSION,
                          do_constant_folding=False,
                          training=torch.onnx.TrainingMode.TRAINING)
        return onnx.load_model_from_string(f.getvalue())

    @staticmethod
    def _value_info_to_desc(value_info):
        '''Builds a {'name', 'shape', 'dtype'} descriptor from an ONNX ValueInfoProto.'''
        tensor_type = value_info.type.tensor_type
        return {'name': value_info.name,
                'shape': [dim.dim_value for dim in tensor_type.shape.dim],
                'dtype': _utils.dtype_onnx_to_torch(tensor_type.elem_type)}

    def _get_initializer_from_graph(self, graph):
        '''Returns a descriptor list of the initializers of ONNX model `graph`.

        Each descriptor is {'name': str, 'shape': [int1, ..., intN], 'dtype': dtype}, where dtype is
        the ONNX element type converted with _utils.dtype_onnx_to_torch.
        For ONNX types, refer to https://github.com/onnx/onnx/blob/master/onnx/onnx.in.proto#L461
        '''
        # TODO: There is a tradeoff between memory footprint and total model export time.
        # Ideally we want to export the model using
        # torch.onnx.export(.., export_params=False, keep_initializers_as_inputs=True)
        # to obtain an ONNX model with minimal size and initializers as input.
        # However, this results in (guessing) assuming only initializer's names end with
        # '.weight' and '.bias'. Otherwise, it is not possible to separate input from initializer
        # after the model is exported. Options are:
        # 1) If memory footprint is more important, export ONNX twice, varying the
        #    keep_initializers_as_inputs flag. The ONNX model is small (400 bytes vs 1.6MB for
        #    MNIST), but export takes twice the time.
        # 2) If total export time is more important, export ONNX once, using export_params=True.
        #    The ONNX model is bigger, but export takes half the time.
        # As performance is not the main goal in this first deliverable, approach 2) is used.
        return [{'name': initializer.name,
                 'shape': list(initializer.dims),
                 'dtype': _utils.dtype_onnx_to_torch(initializer.data_type)}
                for initializer in graph.graph.initializer]

    def _get_input_from_graph(self, graph, initializers_only=False, append_initializers=None):
        '''Returns a descriptor list of the input tensors of ONNX model `graph`.

        An input counts as an initializer when its name appears in
        self._onnx_forward_initializers_desc.
        When `initializers_only=True`, only initializer inputs are returned; otherwise only the
        non-initializer inputs are returned.
            This is used to get the backward initializer list TODO: #ImproveGraphSplitting
        When `append_initializers` is not empty, it is appended to the end of the result list.
            This is used to get the backward input list TODO: #ImproveGraphSplitting
        Each descriptor is {'name': str, 'shape': [int1, ..., intN], 'dtype': dtype}.
        For ONNX types, refer to https://github.com/onnx/onnx/blob/master/onnx/onnx.in.proto#L461
        '''
        initializer_names = {item['name'] for item in self._onnx_forward_initializers_desc}
        inputs = [ORTModule._value_info_to_desc(elem) for elem in graph.graph.input
                  if (elem.name in initializer_names) == initializers_only]
        if append_initializers:
            inputs.extend(append_initializers)
        return inputs

    def _get_gradient_input_from_graph(self, backward_graph, forward_input, forward_initializer, forward_intermediate):
        '''Returns a descriptor list of the gradient inputs (grads wrt module outputs) of `backward_graph`.

        Gradient inputs are found through elimination: every backward-graph input whose name is
        not a forward input, forward initializer or forward intermediate tensor is kept.
        Each descriptor is {'name': str, 'shape': [int1, ..., intN], 'dtype': dtype}.
        For ONNX types, refer to https://github.com/onnx/onnx/blob/master/onnx/onnx.in.proto#L461
        TODO: #ImproveGraphSplitting
        '''
        known_names = {item['name'] for item in forward_input}
        known_names.update(item['name'] for item in forward_initializer)
        known_names.update(item['name'] for item in forward_intermediate)
        return [ORTModule._value_info_to_desc(elem) for elem in backward_graph.graph.input
                if elem.name not in known_names]

    def _get_output_from_graph(self, graph):
        '''Returns a descriptor list of the outputs of ONNX model `graph`, skipping forward initializers.

        Each descriptor is {'name': str, 'shape': [int1, ..., intN], 'dtype': dtype}.
        For ONNX types, refer to https://github.com/onnx/onnx/blob/master/onnx/onnx.in.proto#L461
        '''
        initializer_names = {item['name'] for item in self._onnx_forward_initializers_desc}
        return [ORTModule._value_info_to_desc(elem) for elem in graph.graph.output
                if elem.name not in initializer_names]

    def _get_intermediate_from_forward_graph(self, forward_graph):
        '''Returns a descriptor list of all intermediate (stashed) tensors of `forward_graph`.

        Intermediates are found through elimination: every forward-graph output that is not an
        output of the original exported model (self._onnx_training_outputs_desc) is kept.
        Each descriptor is {'name': str, 'shape': [int1, ..., intN], 'dtype': dtype}.
        TODO: #ImproveGraphSplitting
        '''
        model_output_names = {item['name'] for item in self._onnx_training_outputs_desc}
        return [ORTModule._value_info_to_desc(elem) for elem in forward_graph.graph.output
                if elem.name not in model_output_names]

    @staticmethod
    def _build_gradient_graph(forward_graph, config):
        '''Adds gradient nodes on top of an existing ONNX graph (exported with the training flag).

        If `config.weight_names_to_train` is empty, every initializer of `forward_graph` is
        trained. `config.output_names` is set to the graph outputs. Mutates `config`.
        TODO: #SplittingGraphAtFrontend
        '''
        if not config.weight_names_to_train:
            config.weight_names_to_train = {initializer.name for initializer in forward_graph.graph.initializer}
        config.output_names = {output.name for output in forward_graph.graph.output}
        return onnx.load_model_from_string(
            C.ModuleGradientGraphBuilder().build(forward_graph.SerializeToString(), config))

    @staticmethod
    def _split_forward_and_backward(onnx_model, weight_names_to_train):
        '''Splits the result of _build_gradient_graph into two subgraphs.

        * A forward graph that takes module inputs and weights as input, and produces module
          outputs and ("stashed") intermediate tensors as output.
        * A backward graph that takes input, intermediate tensors, module weights, and gradients
          with respect to the module outputs as inputs, and produces gradients with respect to
          the module weights.
        Nodes whose doc_string is 'Backward pass' go to the backward graph; all others go to the
        forward graph. Sets are iterated in sorted order so the produced graphs are deterministic.
        TODO: #SplittingGraphAtFrontend
        '''
        def is_backward_node(node):
            return node.doc_string == 'Backward pass'

        def keep_nodes(model, predicate):
            # Rebuild the node list, keeping only the nodes for which predicate(node) is true
            kept = [node for node in model.graph.node if predicate(node)]
            model.graph.ClearField('node')
            model.graph.node.extend(kept)

        def add_output(model, name, data_type=None, docstring=None):
            # A standalone ValueInfoProto: appending copies it into graph.output without leaving
            # an extra (typeless) entry behind in graph.value_info.
            new_output = onnx.ValueInfoProto()
            new_output.name = name
            if data_type:
                new_output.type.CopyFrom(data_type)
            if docstring:
                new_output.doc_string = docstring
            model.graph.output.append(new_output)

        def add_input_from_initializer(model, initializer, docstring=None):
            model.graph.input.append(onnx.helper.make_tensor_value_info(
                initializer.name, initializer.data_type, initializer.dims, docstring))

        def add_input(model, name, data_type=None, dims=None, docstring=None):
            model.graph.input.append(onnx.helper.make_tensor_value_info(name, data_type, dims, docstring))

        forward_graph_outputs = set()
        backward_graph_inputs = set()
        backward_graph_outputs = set()

        # Get forward graph
        forward_model = copy.deepcopy(onnx_model)
        initializers = {initializer.name: initializer for initializer in forward_model.graph.initializer}
        forward_graph_initializer_names = set()
        for node in forward_model.graph.node:
            if is_backward_node(node):
                # node belongs to the backward graph
                backward_graph_inputs.update(node.input)
                backward_graph_outputs.update(node.output)
            else:
                # node belongs to the forward graph
                forward_graph_initializer_names.update(name for name in node.input if name in initializers)
                forward_graph_outputs.update(node.output)

        forward_model.graph.ClearField('initializer')
        for initializer_name in sorted(forward_graph_initializer_names):
            forward_model.graph.initializer.append(initializers[initializer_name])
            # training weights need to be added to input
            if initializer_name in weight_names_to_train:
                add_input_from_initializer(forward_model, initializers[initializer_name])

        # outputs from the forward graph that are also inputs of the backward graph are added as graph outputs
        for output_name in sorted(forward_graph_outputs & backward_graph_inputs):
            add_output(forward_model, output_name)

        keep_nodes(forward_model, lambda node: not is_backward_node(node))

        # Get backward graph
        inferred_model = onnx.shape_inference.infer_shapes(onnx_model)
        tensor_elem_types = {value_info.name: value_info.type.tensor_type.elem_type
                             for value_info in inferred_model.graph.value_info}

        backward_model = copy.deepcopy(onnx_model)
        initializers = {initializer.name: initializer for initializer in backward_model.graph.initializer}
        backward_graph_initializer_names = set()
        for input_name in sorted(backward_graph_inputs):
            if input_name in forward_graph_outputs:
                # inputs of the backward graph that are also outputs of the forward graph are added as
                # backward graph inputs; FLOAT is the fallback when shape inference gave no type.
                input_type = tensor_elem_types.get(input_name, onnx.TensorProto.FLOAT)
                if input_name in ORTModule._BERT_BOOL_INPUT_NAMES:
                    input_type = onnx.TensorProto.BOOL
                add_input(backward_model, input_name, input_type)
            elif input_name in forward_graph_initializer_names:
                # inputs from forward graph initializers need to be added to the backward graph input
                add_input_from_initializer(backward_model, initializers[input_name])
            elif input_name in initializers:
                backward_graph_initializer_names.add(input_name)

        # gradient of a forward graph output is an input of the backward graph
        for output in backward_model.graph.output:
            grad_name = output.name + '_grad'
            if grad_name in backward_graph_inputs:
                add_input(backward_model, grad_name, output.type.tensor_type.elem_type)

        backward_model.graph.ClearField('initializer')
        for initializer_name in sorted(backward_graph_initializer_names):
            backward_model.graph.initializer.append(initializers[initializer_name])

        # add gradients of forward initializers as backward graph outputs
        # TODO: need to add gradient of graph input to backward graph output
        grad_suffix = '_grad'
        new_backward_graph_outputs = sorted(
            name for name in backward_graph_outputs
            if name.endswith(grad_suffix) and name[:-len(grad_suffix)] in forward_graph_initializer_names)
        backward_model.graph.ClearField('output')
        for output_name in new_backward_graph_outputs:
            add_output(backward_model, output_name)

        keep_nodes(backward_model, is_backward_node)

        return forward_model, backward_model