mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-06 00:03:22 +00:00
Remove (unnecessary) gradient graph from frontend
This commit is contained in:
parent
4d9267e102
commit
395e082bc3
1 changed file with 9 additions and 9 deletions
|
|
@@ -66,7 +66,6 @@ class ORTModule(torch.nn.Module):
|
|||
# User module is wrapped to use its initializers and save computed gradients
|
||||
self._original_module = module
|
||||
self._onnx_training = None
|
||||
self._onnx_gradient = None
|
||||
|
||||
# Forward pass
|
||||
self._onnx_forward = None
|
||||
|
|
@@ -158,18 +157,17 @@ class ORTModule(torch.nn.Module):
|
|||
grad_builder_config = C.ModuleGradientGraphBuilderConfiguration()
|
||||
# TODO: PyTorch exporter bug: changes the initializer order
|
||||
initializer_names = [p[0] for p in self._original_module.named_parameters()]
|
||||
self._onnx_gradient, self._onnx_forward, self._onnx_backward, self._onnx_graphs_info = \
|
||||
ORTModule._build_fw_bw_grad_graphs(self._onnx_training, grad_builder_config, initializer_names)
|
||||
onnx_gradient, self._onnx_forward, self._onnx_backward, self._onnx_graphs_info = \
|
||||
ORTModule._build_fw_bw_grad_graphs(self._onnx_training, grad_builder_config,
|
||||
initializer_names,
|
||||
self._save_onnx)
|
||||
|
||||
if self._save_onnx:
|
||||
onnx.save(self._onnx_training, self._save_onnx_prefix + '_full_training.onnx')
|
||||
onnx.save(self._onnx_gradient, self._save_onnx_prefix + '_with_grad.onnx')
|
||||
onnx.save(onnx_gradient, self._save_onnx_prefix + '_with_grad.onnx')
|
||||
onnx.save(self._onnx_forward, self._save_onnx_prefix + '_forward.onnx')
|
||||
onnx.save(self._onnx_backward, self._save_onnx_prefix + '_backward.onnx')
|
||||
|
||||
# TODO: Consider moving this to the backend. We don't want to append '_grad' to get correct tensor names
|
||||
self._onnx_graphs_types = ORTModule._get_io_info_from_onnx_graph(self._onnx_forward, self._onnx_graphs_info)
|
||||
|
||||
self._forward_session = onnxruntime.InferenceSession(self._onnx_forward.SerializeToString())
|
||||
self._backward_session = onnxruntime.InferenceSession(self._onnx_backward.SerializeToString())
|
||||
|
||||
|
|
@@ -381,7 +379,7 @@ class ORTModule(torch.nn.Module):
|
|||
|
||||
|
||||
@staticmethod
|
||||
def _build_fw_bw_grad_graphs(forward_graph, config, initializer_names=[]):
|
||||
def _build_fw_bw_grad_graphs(forward_graph, config, initializer_names=[], include_gradient_model=False):
|
||||
'''Adds gradient nodes on top of an existing ONNX graph (with training flag)'''
|
||||
if not config.initializer_names_to_train:
|
||||
if not initializer_names:
|
||||
|
|
@@ -402,7 +400,9 @@ class ORTModule(torch.nn.Module):
|
|||
module_gradient_graph_builder.build_and_split(forward_graph.SerializeToString(), config)
|
||||
forward_model = onnx.load_model_from_string(module_gradient_graph_builder.get_forward_model())
|
||||
backward_model = onnx.load_model_from_string(module_gradient_graph_builder.get_backward_model())
|
||||
gradient_model = onnx.load_model_from_string(module_gradient_graph_builder.get_gradient_model())
|
||||
gradient_model = None
|
||||
if include_gradient_model:
|
||||
gradient_model = onnx.load_model_from_string(module_gradient_graph_builder.get_gradient_model())
|
||||
split_graphs_info = module_gradient_graph_builder.get_split_graphs_info()
|
||||
|
||||
return gradient_model, forward_model, backward_model, split_graphs_info
|
||||
|
|
|
|||
Loading…
Reference in a new issue