Save module output for backward if needed (#8010)

* Save module output for backward if needed
This commit is contained in:
Sherlock 2021-06-10 09:56:35 -07:00 committed by GitHub
parent c74265667e
commit 2a74f5e85b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 97 additions and 0 deletions

View file

@ -95,6 +95,9 @@ Status OrtModuleGraphBuilder::Build(const std::vector<std::vector<int64_t>>* inp
// Reorder outputs.
ReorderOutputs();
// Find module outputs needed for backward computation
FindModuleOutputNeededForBackward();
return Status::OK();
}
@ -326,5 +329,45 @@ void OrtModuleGraphBuilder::ReorderOutputs() {
gradient_graph.SetOutputs(new_output_args);
}
void OrtModuleGraphBuilder::FindModuleOutputNeededForBackward() {
Graph& gradient_graph = gradient_model_->MainGraph();
gradient_graph.Resolve();
GraphViewer gradient_graph_viewer(gradient_graph);
const auto& exec_order = gradient_graph_viewer.GetNodesInTopologicalOrder();
size_t yield_node_order = 0;
bool yield_node_found = false;
std::unordered_map<NodeIndex, size_t> id_to_exec_order;
for (size_t i = 0; i < exec_order.size(); ++i) {
if (gradient_graph_viewer.GetNode(exec_order[i])->OpType() == "YieldOp") {
yield_node_order = i;
yield_node_found = true;
}
id_to_exec_order.insert({exec_order[i], i});
}
ORT_ENFORCE(yield_node_found, "YieldOp is not found in the training graph");
const Node* yield_node = gradient_graph_viewer.GetNode(exec_order[yield_node_order]);
auto yield_input_node_args = yield_node->InputDefs();
for (size_t i = 0; i < yield_input_node_args.size(); ++i) {
const NodeArg* yield_input = yield_input_node_args[i];
const Node* producer_node = gradient_graph.GetProducerNode(yield_input->Name());
if (producer_node->OpType() == "Identity") {
yield_input = producer_node->InputDefs()[0];
}
std::vector<const Node*> consumer_nodes = gradient_graph.GetConsumerNodes(yield_input->Name());
for (const Node* n : consumer_nodes) {
// If a module output has a consumer that is executed after the YieldOp, marked it needed for backward
if (id_to_exec_order[n->Index()] > yield_node_order) {
graph_info_.module_output_indices_requires_save_for_backward.emplace_back(i);
break;
}
}
}
}
} // namespace training
} // namespace onnxruntime

View file

@ -56,6 +56,9 @@ struct GraphInfo {
// Indices of output grads that need to be materialized to full size all-0 tensor.
// Otherwise, we can use scalar-0 tensor.
std::vector<size_t> output_grad_indices_require_full_shape{};
// Indices of module outputs that are needed for backward computation
std::vector<size_t> module_output_indices_requires_save_for_backward{};
// Names of module outputs' gradient
std::vector<std::string> module_output_gradient_name{};
};
@ -111,6 +114,9 @@ class OrtModuleGraphBuilder {
// Reorder gradient graph outputs.
void ReorderOutputs();
// Find the module outputs that are needed for backward computation
void FindModuleOutputNeededForBackward();
std::shared_ptr<onnxruntime::Model> model_;
std::shared_ptr<onnxruntime::Model> inference_optimized_model_;
std::shared_ptr<onnxruntime::Model> gradient_model_;

View file

@ -615,6 +615,7 @@ void addObjectMethodsForTraining(py::module& m) {
.def_readwrite("user_output_names", &GraphInfo::user_output_names)
.def_readwrite("output_grad_indices_non_differentiable", &GraphInfo::output_grad_indices_non_differentiable)
.def_readwrite("output_grad_indices_require_full_shape", &GraphInfo::output_grad_indices_require_full_shape)
.def_readwrite("module_output_indices_requires_save_for_backward", &GraphInfo::module_output_indices_requires_save_for_backward)
.def_readwrite("module_output_gradient_name", &GraphInfo::module_output_gradient_name);
py::class_<OrtModuleGraphBuilder> ortmodule_graph_builder(m, "OrtModuleGraphBuilder");

View file

@ -115,6 +115,14 @@ class TrainingManager(GraphExecutionManager):
# converted to a tensor filled with zeros prior to calling backward.
# Save shape, device and type info to ctx for materializing tensor in backward if output grad is None.
ctx.set_materialize_grads(False)
# Mark the output tensors needed in backward computation
# ORT is NOT relying on save_for_backward() to actually save the tensor,
# as this tensor is also kept in ORT's PartialGraphState
# This call is to invoke pytorch's version check to detect the potential inplace corruption
for idx in self._graph_info.module_output_indices_requires_save_for_backward:
ctx.save_for_backward(user_outputs[idx])
return user_outputs
@staticmethod
@ -124,6 +132,9 @@ class TrainingManager(GraphExecutionManager):
assert ctx.run_info is not None, 'forward() or __call__() methods must be called before backward()'
_utils._check_same_device(self._device, "Input argument to backward", *grad_outputs)
# Unpack saved_tensor to trigger version detection that catches inplace corruption
_ = ctx.saved_tensors
# Use IO binding
# Push user output grads to ONNX backend.
backward_inputs = C.OrtValueVector()

View file

@ -1079,6 +1079,42 @@ def test_input_requires_grad_backward_creates_input_grad_as_required0(device):
# backward() is from y2, so grad of fc1.weight and fc1.bias will not be calculated.
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model, none_pt_params=['fc1.weight', 'fc1.bias'])
@pytest.mark.parametrize("device", ['cuda'])
def test_model_output_with_inplace_update(device):
    """Check that an in-place update of a module output is rejected during backward.

    The model ends in a Softmax, whose gradient depends on its own output. Each
    step modifies that output in place before calling backward(); both the plain
    PyTorch module and its ORTModule wrapper are expected to raise an error whose
    message mentions the in-place modification.
    """

    class NeuralNetWithGradNeedOutput(torch.nn.Module):
        def __init__(self, input_size, hidden_size):
            super().__init__()
            self.fc1_1 = torch.nn.Linear(input_size, hidden_size)
            # The gradient of Softmax is computed from its output
            self.act = torch.nn.Softmax(dim=1)

        def forward(self, input1):
            hidden = self.fc1_1(input1)
            return self.act(hidden)

    def run_step(model, x1):
        output = model(x1)
        # Modify the module output in place before it is used by backward
        output.add_(1)
        loss = output.sum()
        loss.backward()
        return loss

    batch_size, in_features, hidden_features = 32, 784, 500
    pt_model = NeuralNetWithGradNeedOutput(in_features, hidden_features).to(device)
    ort_model = ORTModule(copy.deepcopy(pt_model))

    pt_x1 = torch.randn(batch_size, in_features, device=device, requires_grad=True)
    ort_x1 = pt_x1.clone()

    # Reference behaviour: PyTorch itself reports the in-place modification
    with pytest.raises(Exception) as ex_info:
        run_step(pt_model, pt_x1)
    assert "modified by an inplace operation" in str(ex_info.value)

    # ORTModule is expected to report the same problem
    with pytest.raises(Exception) as ex_info:
        run_step(ort_model, ort_x1)
    assert "modified by an inplace operation" in str(ex_info.value)
@pytest.mark.parametrize("device", ['cuda'])
def test_loss_combines_two_outputs_with_dependency(device):