mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
Save module output for backward if needed (#8010)
* Save module output for backward if needed
This commit is contained in:
parent
c74265667e
commit
2a74f5e85b
5 changed files with 97 additions and 0 deletions
|
|
@ -95,6 +95,9 @@ Status OrtModuleGraphBuilder::Build(const std::vector<std::vector<int64_t>>* inp
|
|||
// Reorder outputs.
|
||||
ReorderOutputs();
|
||||
|
||||
// Find module outputs needed for backward computation
|
||||
FindModuleOutputNeededForBackward();
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
@ -326,5 +329,45 @@ void OrtModuleGraphBuilder::ReorderOutputs() {
|
|||
gradient_graph.SetOutputs(new_output_args);
|
||||
}
|
||||
|
||||
void OrtModuleGraphBuilder::FindModuleOutputNeededForBackward() {
|
||||
Graph& gradient_graph = gradient_model_->MainGraph();
|
||||
gradient_graph.Resolve();
|
||||
GraphViewer gradient_graph_viewer(gradient_graph);
|
||||
const auto& exec_order = gradient_graph_viewer.GetNodesInTopologicalOrder();
|
||||
|
||||
size_t yield_node_order = 0;
|
||||
bool yield_node_found = false;
|
||||
std::unordered_map<NodeIndex, size_t> id_to_exec_order;
|
||||
for (size_t i = 0; i < exec_order.size(); ++i) {
|
||||
if (gradient_graph_viewer.GetNode(exec_order[i])->OpType() == "YieldOp") {
|
||||
yield_node_order = i;
|
||||
yield_node_found = true;
|
||||
}
|
||||
id_to_exec_order.insert({exec_order[i], i});
|
||||
}
|
||||
ORT_ENFORCE(yield_node_found, "YieldOp is not found in the training graph");
|
||||
|
||||
const Node* yield_node = gradient_graph_viewer.GetNode(exec_order[yield_node_order]);
|
||||
auto yield_input_node_args = yield_node->InputDefs();
|
||||
|
||||
for (size_t i = 0; i < yield_input_node_args.size(); ++i) {
|
||||
const NodeArg* yield_input = yield_input_node_args[i];
|
||||
|
||||
const Node* producer_node = gradient_graph.GetProducerNode(yield_input->Name());
|
||||
if (producer_node->OpType() == "Identity") {
|
||||
yield_input = producer_node->InputDefs()[0];
|
||||
}
|
||||
|
||||
std::vector<const Node*> consumer_nodes = gradient_graph.GetConsumerNodes(yield_input->Name());
|
||||
for (const Node* n : consumer_nodes) {
|
||||
// If a module output has a consumer that is executed after the YieldOp, marked it needed for backward
|
||||
if (id_to_exec_order[n->Index()] > yield_node_order) {
|
||||
graph_info_.module_output_indices_requires_save_for_backward.emplace_back(i);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace training
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -56,6 +56,9 @@ struct GraphInfo {
|
|||
// Indices of output grads that need to be materialized to full size all-0 tensor.
|
||||
// Otherwise, we can use scalar-0 tensor.
|
||||
std::vector<size_t> output_grad_indices_require_full_shape{};
|
||||
// Indices of module outputs that are needed for backward computation
|
||||
std::vector<size_t> module_output_indices_requires_save_for_backward{};
|
||||
// Names of module outputs' gradient
|
||||
std::vector<std::string> module_output_gradient_name{};
|
||||
};
|
||||
|
||||
|
|
@ -111,6 +114,9 @@ class OrtModuleGraphBuilder {
|
|||
// Reorder gradient graph outputs.
|
||||
void ReorderOutputs();
|
||||
|
||||
// Find the module outputs that are needed for backward computation
|
||||
void FindModuleOutputNeededForBackward();
|
||||
|
||||
std::shared_ptr<onnxruntime::Model> model_;
|
||||
std::shared_ptr<onnxruntime::Model> inference_optimized_model_;
|
||||
std::shared_ptr<onnxruntime::Model> gradient_model_;
|
||||
|
|
|
|||
|
|
@ -615,6 +615,7 @@ void addObjectMethodsForTraining(py::module& m) {
|
|||
.def_readwrite("user_output_names", &GraphInfo::user_output_names)
|
||||
.def_readwrite("output_grad_indices_non_differentiable", &GraphInfo::output_grad_indices_non_differentiable)
|
||||
.def_readwrite("output_grad_indices_require_full_shape", &GraphInfo::output_grad_indices_require_full_shape)
|
||||
.def_readwrite("module_output_indices_requires_save_for_backward", &GraphInfo::module_output_indices_requires_save_for_backward)
|
||||
.def_readwrite("module_output_gradient_name", &GraphInfo::module_output_gradient_name);
|
||||
|
||||
py::class_<OrtModuleGraphBuilder> ortmodule_graph_builder(m, "OrtModuleGraphBuilder");
|
||||
|
|
|
|||
|
|
@ -115,6 +115,14 @@ class TrainingManager(GraphExecutionManager):
|
|||
# converted to a tensor filled with zeros prior to calling backward.
|
||||
# Save shape, device and type info to ctx for materializing tensor in backward if output grad is None.
|
||||
ctx.set_materialize_grads(False)
|
||||
|
||||
# Mark the output tensors needed in backward computation
|
||||
# ORT is NOT relying on save_for_backward() to actually save the tensor,
|
||||
# as this tensor is also kept in ORT's PartialGraphState
|
||||
# This call is to invoke pytorch's version check to detect the potential inplace corruption
|
||||
for idx in self._graph_info.module_output_indices_requires_save_for_backward:
|
||||
ctx.save_for_backward(user_outputs[idx])
|
||||
|
||||
return user_outputs
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -124,6 +132,9 @@ class TrainingManager(GraphExecutionManager):
|
|||
assert ctx.run_info is not None, 'forward() or __call__() methods must be called before backward()'
|
||||
_utils._check_same_device(self._device, "Input argument to backward", *grad_outputs)
|
||||
|
||||
# Unpack saved_tensors to trigger version detection that catches inplace corruption
|
||||
_ = ctx.saved_tensors
|
||||
|
||||
# Use IO binding
|
||||
# Push user output grads to ONNX backend.
|
||||
backward_inputs = C.OrtValueVector()
|
||||
|
|
|
|||
|
|
@ -1079,6 +1079,42 @@ def test_input_requires_grad_backward_creates_input_grad_as_required0(device):
|
|||
# backward() is from y2, so grad of fc1.weight and fc1.bias will not be calculated.
|
||||
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model, none_pt_params=['fc1.weight', 'fc1.bias'])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ['cuda'])
def test_model_output_with_inplace_update(device):
    """An in-place update of a module output that backward needs must raise.

    The model ends in a Softmax, whose gradient depends on its own output.
    Modifying that output in place before calling backward() is expected to
    fail with an "inplace operation" error for both the plain PyTorch module
    and its ORTModule-wrapped copy.
    """
    class NeuralNetWithGradNeedOutput(torch.nn.Module):
        def __init__(self, input_size, hidden_size):
            super().__init__()
            self.fc1_1 = torch.nn.Linear(input_size, hidden_size)
            # The gradient of Softmax is computed from its forward output
            self.act = torch.nn.Softmax(dim=1)

        def forward(self, input1):
            hidden = self.fc1_1(input1)
            return self.act(hidden)

    def run_step(model, x1):
        output = model(x1)
        # Corrupt the module output in place before it is used by backward
        output.add_(1)
        loss = output.sum()
        loss.backward()
        return loss

    batch_size, in_features, hidden_features = 32, 784, 500
    expected_message = "modified by an inplace operation"

    pt_model = NeuralNetWithGradNeedOutput(in_features, hidden_features).to(device)
    ort_model = ORTModule(copy.deepcopy(pt_model))

    pt_x1 = torch.randn(batch_size, in_features, device=device, requires_grad=True)
    ort_x1 = pt_x1.clone()

    # Same check for the reference PyTorch model first, then the ORTModule copy
    for model, inputs in ((pt_model, pt_x1), (ort_model, ort_x1)):
        with pytest.raises(Exception) as ex_info:
            run_step(model, inputs)
        assert expected_message in str(ex_info.value)
|
||||
|
||||
@pytest.mark.parametrize("device", ['cuda'])
|
||||
def test_loss_combines_two_outputs_with_dependency(device):
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue