Remove (unnecessary) gradient graph from frontend

This commit is contained in:
Thiago Crepaldi 2020-11-19 13:36:53 -08:00
parent 4d9267e102
commit 395e082bc3

View file

@ -66,7 +66,6 @@ class ORTModule(torch.nn.Module):
# User module is wrapped to use its initializers and save computed gradients
self._original_module = module
self._onnx_training = None
self._onnx_gradient = None
# Forward pass
self._onnx_forward = None
@ -158,18 +157,17 @@ class ORTModule(torch.nn.Module):
grad_builder_config = C.ModuleGradientGraphBuilderConfiguration()
# TODO: PyTorch exporter bug: changes the initializer order
initializer_names = [p[0] for p in self._original_module.named_parameters()]
self._onnx_gradient, self._onnx_forward, self._onnx_backward, self._onnx_graphs_info = \
ORTModule._build_fw_bw_grad_graphs(self._onnx_training, grad_builder_config, initializer_names)
onnx_gradient, self._onnx_forward, self._onnx_backward, self._onnx_graphs_info = \
ORTModule._build_fw_bw_grad_graphs(self._onnx_training, grad_builder_config,
initializer_names,
self._save_onnx)
if self._save_onnx:
onnx.save(self._onnx_training, self._save_onnx_prefix + '_full_training.onnx')
onnx.save(self._onnx_gradient, self._save_onnx_prefix + '_with_grad.onnx')
onnx.save(onnx_gradient, self._save_onnx_prefix + '_with_grad.onnx')
onnx.save(self._onnx_forward, self._save_onnx_prefix + '_forward.onnx')
onnx.save(self._onnx_backward, self._save_onnx_prefix + '_backward.onnx')
# TODO: Consider moving this to the backend. We don't want to append '_grad' to get correct tensor names
self._onnx_graphs_types = ORTModule._get_io_info_from_onnx_graph(self._onnx_forward, self._onnx_graphs_info)
self._forward_session = onnxruntime.InferenceSession(self._onnx_forward.SerializeToString())
self._backward_session = onnxruntime.InferenceSession(self._onnx_backward.SerializeToString())
@ -381,7 +379,7 @@ class ORTModule(torch.nn.Module):
@staticmethod
def _build_fw_bw_grad_graphs(forward_graph, config, initializer_names=[]):
def _build_fw_bw_grad_graphs(forward_graph, config, initializer_names=[], include_gradient_model=False):
'''Adds gradient nodes on top of an existing ONNX graph (with training flag)'''
if not config.initializer_names_to_train:
if not initializer_names:
@ -402,7 +400,9 @@ class ORTModule(torch.nn.Module):
module_gradient_graph_builder.build_and_split(forward_graph.SerializeToString(), config)
forward_model = onnx.load_model_from_string(module_gradient_graph_builder.get_forward_model())
backward_model = onnx.load_model_from_string(module_gradient_graph_builder.get_backward_model())
gradient_model = onnx.load_model_from_string(module_gradient_graph_builder.get_gradient_model())
gradient_model = None
if include_gradient_model:
gradient_model = onnx.load_model_from_string(module_gradient_graph_builder.get_gradient_model())
split_graphs_info = module_gradient_graph_builder.get_split_graphs_info()
return gradient_model, forward_model, backward_model, split_graphs_info