From 2a74f5e85baedb9b5e52a601d14e0e8974a4776d Mon Sep 17 00:00:00 2001 From: Sherlock Date: Thu, 10 Jun 2021 09:56:35 -0700 Subject: [PATCH] Save module output for backward if needed (#8010) * Save module output for backward if needed --- .../core/framework/ortmodule_graph_builder.cc | 43 +++++++++++++++++++ .../core/framework/ortmodule_graph_builder.h | 6 +++ .../python/orttraining_pybind_state.cc | 1 + .../training/ortmodule/_training_manager.py | 11 +++++ .../python/orttraining_test_ortmodule_api.py | 36 ++++++++++++++++ 5 files changed, 97 insertions(+) diff --git a/orttraining/orttraining/core/framework/ortmodule_graph_builder.cc b/orttraining/orttraining/core/framework/ortmodule_graph_builder.cc index e252313c20..fbe4416722 100644 --- a/orttraining/orttraining/core/framework/ortmodule_graph_builder.cc +++ b/orttraining/orttraining/core/framework/ortmodule_graph_builder.cc @@ -95,6 +95,9 @@ Status OrtModuleGraphBuilder::Build(const std::vector>* inp // Reorder outputs. ReorderOutputs(); + // Find module outputs needed for backward computation + FindModuleOutputNeededForBackward(); + return Status::OK(); } @@ -326,5 +329,45 @@ void OrtModuleGraphBuilder::ReorderOutputs() { gradient_graph.SetOutputs(new_output_args); } +void OrtModuleGraphBuilder::FindModuleOutputNeededForBackward() { + Graph& gradient_graph = gradient_model_->MainGraph(); + gradient_graph.Resolve(); + GraphViewer gradient_graph_viewer(gradient_graph); + const auto& exec_order = gradient_graph_viewer.GetNodesInTopologicalOrder(); + + size_t yield_node_order = 0; + bool yield_node_found = false; + std::unordered_map id_to_exec_order; + for (size_t i = 0; i < exec_order.size(); ++i) { + if (gradient_graph_viewer.GetNode(exec_order[i])->OpType() == "YieldOp") { + yield_node_order = i; + yield_node_found = true; + } + id_to_exec_order.insert({exec_order[i], i}); + } + ORT_ENFORCE(yield_node_found, "YieldOp is not found in the training graph"); + + const Node* yield_node = gradient_graph_viewer.GetNode(exec_order[yield_node_order]); + auto yield_input_node_args = yield_node->InputDefs(); + + for (size_t i = 0; i < yield_input_node_args.size(); ++i) { + const NodeArg* yield_input = yield_input_node_args[i]; + + const Node* producer_node = gradient_graph.GetProducerNode(yield_input->Name()); + if (producer_node->OpType() == "Identity") { + yield_input = producer_node->InputDefs()[0]; + } + + std::vector consumer_nodes = gradient_graph.GetConsumerNodes(yield_input->Name()); + for (const Node* n : consumer_nodes) { + // If a module output has a consumer that is executed after the YieldOp, marked it needed for backward + if (id_to_exec_order[n->Index()] > yield_node_order) { + graph_info_.module_output_indices_requires_save_for_backward.emplace_back(i); + break; + } + } + } +} + } // namespace training } // namespace onnxruntime diff --git a/orttraining/orttraining/core/framework/ortmodule_graph_builder.h b/orttraining/orttraining/core/framework/ortmodule_graph_builder.h index 9925be3a20..ed4cb23444 100644 --- a/orttraining/orttraining/core/framework/ortmodule_graph_builder.h +++ b/orttraining/orttraining/core/framework/ortmodule_graph_builder.h @@ -56,6 +56,9 @@ struct GraphInfo { // Indices of output grads that need to be materialized to full size all-0 tensor. // Otherwise, we can use scalar-0 tensor. std::vector output_grad_indices_require_full_shape{}; + // Indices of module output that are needed for backward computation + std::vector module_output_indices_requires_save_for_backward{}; + // Names of module outputs' gradient std::vector module_output_gradient_name{}; }; @@ -111,6 +114,9 @@ class OrtModuleGraphBuilder { // Reorder gradient graph outputs. void ReorderOutputs(); + // Find the module output that are needed for backward computation + void FindModuleOutputNeededForBackward(); + std::shared_ptr model_; std::shared_ptr inference_optimized_model_; std::shared_ptr gradient_model_; diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc index 5601e5aa92..34dd19335c 100644 --- a/orttraining/orttraining/python/orttraining_pybind_state.cc +++ b/orttraining/orttraining/python/orttraining_pybind_state.cc @@ -615,6 +615,7 @@ void addObjectMethodsForTraining(py::module& m) { .def_readwrite("user_output_names", &GraphInfo::user_output_names) .def_readwrite("output_grad_indices_non_differentiable", &GraphInfo::output_grad_indices_non_differentiable) .def_readwrite("output_grad_indices_require_full_shape", &GraphInfo::output_grad_indices_require_full_shape) + .def_readwrite("module_output_indices_requires_save_for_backward", &GraphInfo::module_output_indices_requires_save_for_backward) .def_readwrite("module_output_gradient_name", &GraphInfo::module_output_gradient_name); py::class_ ortmodule_graph_builder(m, "OrtModuleGraphBuilder"); diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py index 462000d580..23395e796c 100644 --- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py @@ -115,6 +115,14 @@ class TrainingManager(GraphExecutionManager): # converted to a tensor filled with zeros prior to calling backward. # Save shape, device and type info to ctx for materializing tensor in backward if output grad is None. ctx.set_materialize_grads(False) + + # Mark the outputs tensors needed in backward computation + # ORT is NOT relying on save_for_backward() to actually save the tensor, + # as this tensor is also kept in ORT's PartialGraphState + # This call is to invoke pytorch's version check to detect the potential inplace corruption + for idx in self._graph_info.module_output_indices_requires_save_for_backward: + ctx.save_for_backward(user_outputs[idx]) + return user_outputs @staticmethod @@ -124,6 +132,9 @@ class TrainingManager(GraphExecutionManager): assert ctx.run_info is not None, 'forward() or __call__() methods must be called before backward()' _utils._check_same_device(self._device, "Input argument to backward", *grad_outputs) + # Unpack saved_tensor to trigger version detection that catches inplace corruption + _ = ctx.saved_tensors + # Use IO binding # Push user output grads to ONNX backend. backward_inputs = C.OrtValueVector() diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index e3e6ca5727..9a8a387142 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -1079,6 +1079,42 @@ def test_input_requires_grad_backward_creates_input_grad_as_required0(device): # backward() is from y2, so grad of fc1.weight and fc1.bias will not be calculated. _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model, none_pt_params=['fc1.weight', 'fc1.bias']) + +@pytest.mark.parametrize("device", ['cuda']) +def test_model_output_with_inplace_update(device): + class NeuralNetWithGradNeedOutput(torch.nn.Module): + def __init__(self, input_size, hidden_size): + super(NeuralNetWithGradNeedOutput, self).__init__() + self.fc1_1 = torch.nn.Linear(input_size, hidden_size) + # Softmax's gradient is depending on its output + self.act = torch.nn.Softmax(dim=1) + + def forward(self, input1): + out1 = self.act(self.fc1_1(input1)) + return out1 + + def run_step(model, x1): + y1 = model(x1) + y1.add_(1) # inplace update to module output + y1 = y1.sum() + y1.backward() + return y1 + + N, D_in, H = 32, 784, 500 + pt_model = NeuralNetWithGradNeedOutput(D_in, H).to(device) + ort_model = ORTModule(copy.deepcopy(pt_model)) + + pt_x1 = torch.randn(N, D_in, device=device, requires_grad=True) + ort_x1 = pt_x1.clone() + + with pytest.raises(Exception) as ex_info: + pt_y1 = run_step(pt_model, pt_x1) + assert "modified by an inplace operation" in str(ex_info.value) + + with pytest.raises(Exception) as ex_info: + ort_y1 = run_step(ort_model, ort_x1) + assert "modified by an inplace operation" in str(ex_info.value) + @pytest.mark.parametrize("device", ['cuda']) def test_loss_combines_two_outputs_with_dependency(device):