diff --git a/orttraining/orttraining/python/training/ortmodule.py b/orttraining/orttraining/python/training/ortmodule.py index 2f8654e3a2..5355348276 100644 --- a/orttraining/orttraining/python/training/ortmodule.py +++ b/orttraining/orttraining/python/training/ortmodule.py @@ -66,7 +66,6 @@ class ORTModule(torch.nn.Module): # User module is wrapped to use its initializers and save computed gradients self._original_module = module self._onnx_training = None - self._onnx_gradient = None # Forward pass self._onnx_forward = None @@ -158,18 +157,17 @@ class ORTModule(torch.nn.Module): grad_builder_config = C.ModuleGradientGraphBuilderConfiguration() # TODO: PyTorch exporter bug: changes the initializer order initializer_names = [p[0] for p in self._original_module.named_parameters()] - self._onnx_gradient, self._onnx_forward, self._onnx_backward, self._onnx_graphs_info = \ - ORTModule._build_fw_bw_grad_graphs(self._onnx_training, grad_builder_config, initializer_names) + onnx_gradient, self._onnx_forward, self._onnx_backward, self._onnx_graphs_info = \ + ORTModule._build_fw_bw_grad_graphs(self._onnx_training, grad_builder_config, + initializer_names, + self._save_onnx) if self._save_onnx: onnx.save(self._onnx_training, self._save_onnx_prefix + '_full_training.onnx') - onnx.save(self._onnx_gradient, self._save_onnx_prefix + '_with_grad.onnx') + onnx.save(onnx_gradient, self._save_onnx_prefix + '_with_grad.onnx') onnx.save(self._onnx_forward, self._save_onnx_prefix + '_forward.onnx') onnx.save(self._onnx_backward, self._save_onnx_prefix + '_backward.onnx') - # TODO: Consider moving this to the backend. We don't want to append '_grad' to get correct tensor names - self._onnx_graphs_types = ORTModule._get_io_info_from_onnx_graph(self._onnx_forward, self._onnx_graphs_info) - self._forward_session = onnxruntime.InferenceSession(self._onnx_forward.SerializeToString()) self._backward_session = onnxruntime.InferenceSession(self._onnx_backward.SerializeToString()) @@ -381,7 +379,7 @@ class ORTModule(torch.nn.Module): @staticmethod - def _build_fw_bw_grad_graphs(forward_graph, config, initializer_names=[]): + def _build_fw_bw_grad_graphs(forward_graph, config, initializer_names=[], include_gradient_model=False): '''Adds gradient nodes on top of an existing ONNX graph (with training flag)''' if not config.initializer_names_to_train: if not initializer_names: @@ -402,7 +400,9 @@ class ORTModule(torch.nn.Module): module_gradient_graph_builder.build_and_split(forward_graph.SerializeToString(), config) forward_model = onnx.load_model_from_string(module_gradient_graph_builder.get_forward_model()) backward_model = onnx.load_model_from_string(module_gradient_graph_builder.get_backward_model()) - gradient_model = onnx.load_model_from_string(module_gradient_graph_builder.get_gradient_model()) + gradient_model = None + if include_gradient_model: + gradient_model = onnx.load_model_from_string(module_gradient_graph_builder.get_gradient_model()) split_graphs_info = module_gradient_graph_builder.get_split_graphs_info() return gradient_model, forward_model, backward_model, split_graphs_info