mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
custom autograd func memory refinement (#8993)
* Release torch tensor referenced by torch gradient graph (created in PythonOp) * Update orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/torch_interop_utils/torch_interop_utils.cc * refine with comments Co-authored-by: Wei-Sheng Chin <wschin@outlook.com>
This commit is contained in:
parent
d39959172f
commit
d209fe29b9
2 changed files with 100 additions and 34 deletions
|
|
@ -81,6 +81,9 @@ def call_python_forward_function(
|
|||
def register_context(result):
|
||||
# Search for context among all outputs.
|
||||
ctx = None
|
||||
# All forward outputs of torch.autograd.Function shared a same gradient function pointer,
|
||||
# so here we just get the first tensor having grad_fn attribute.
|
||||
# (https://github.com/pytorch/pytorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/custom_function.cpp#L267)
|
||||
first_tensor_output = None
|
||||
for arg in result:
|
||||
if not isinstance(arg, torch.Tensor) or not hasattr(arg, 'grad_fn'):
|
||||
|
|
@ -93,6 +96,22 @@ def call_python_forward_function(
|
|||
if training_mode_flag:
|
||||
# Must extract one valid context from result tensors.
|
||||
assert ctx is not None
|
||||
|
||||
# FORWARD BACKWARD FUNCTION CONNECTIONS
|
||||
# input_1 (leaf, constructed by from_dlpack) <----reference---- AccumulateGrad gradient function
|
||||
# ↓ ↑
|
||||
# autograd.Function apply() ------------> autograd.Function backward()
|
||||
# ↓ | ↑
|
||||
# output_1, output_2 --- shared_ptr<PyNode> --- ↑
|
||||
# ↓ previous gradient function
|
||||
|
||||
# We remove the edges starting between current autograd.Function's gradient function and
|
||||
# it's input's gradient function (e.g. AccumulateGrad gradient function), then
|
||||
# AccumulateGrad gradient function will be destroyed, releasing the reference to input_1
|
||||
# (https://github.com/pytorch/pytorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/functions/accumulate_grad.cpp#L21).
|
||||
# The next edges are stored in Node, with which we can get next gradient function.
|
||||
# https://github.com/pytorch/pytorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/function.h#L527
|
||||
torch_interop_utils.clear_grad_fns_for_next_edges(first_tensor_output, ctx.saved_tensors)
|
||||
torch_interop_utils.register_grad_fn(id(ctx), first_tensor_output)
|
||||
else:
|
||||
# Context must not present under non-training mode.
|
||||
|
|
@ -158,36 +177,37 @@ def call_python_backward_function(
|
|||
inplace: indicates if args can be modified inside the custom function.
|
||||
args: inputs to "backward_function".
|
||||
'''
|
||||
def wrap_all_outputs(result):
|
||||
if isinstance(result, torch.Tensor):
|
||||
return [to_dlpack(result)]
|
||||
elif isinstance(result, tuple) or isinstance(result, list):
|
||||
return [to_dlpack(value) if value is not None else None for value in result]
|
||||
else:
|
||||
raise wrap_exception(ORTModuleIOError,
|
||||
TypeError(f'ORTModule does not support the following model output type {type(result)}.'))
|
||||
with torch.no_grad():
|
||||
def wrap_all_outputs(result):
|
||||
if isinstance(result, torch.Tensor):
|
||||
return [to_dlpack(result)]
|
||||
elif isinstance(result, tuple) or isinstance(result, list):
|
||||
return [to_dlpack(value) if value is not None else None for value in result]
|
||||
else:
|
||||
raise wrap_exception(ORTModuleIOError,
|
||||
TypeError(f'ORTModule does not support the following model output type {type(result)}.'))
|
||||
|
||||
try:
|
||||
# Backward inputs should not require gradients.
|
||||
assert all(grad_flag == 0 for grad_flag in requires_grad_flags)
|
||||
try:
|
||||
# Backward inputs should not require gradients.
|
||||
assert all(grad_flag == 0 for grad_flag in requires_grad_flags)
|
||||
|
||||
# Prepare inputs for calling Python function.
|
||||
wrapped_args = list(wrap_as_dlpack_or_not(grad_flag, tensor_flag, inplace, is_training_mode, arg)
|
||||
for grad_flag, tensor_flag, arg in zip(requires_grad_flags, tensor_type_flags, args))
|
||||
# Prepare inputs for calling Python function.
|
||||
wrapped_args = list(wrap_as_dlpack_or_not(grad_flag, tensor_flag, inplace, is_training_mode, arg)
|
||||
for grad_flag, tensor_flag, arg in zip(requires_grad_flags, tensor_type_flags, args))
|
||||
|
||||
# Call Python function.
|
||||
result = backward_function(*wrapped_args)
|
||||
# Call Python function.
|
||||
result = backward_function(*wrapped_args)
|
||||
|
||||
# Extract results as DLPack tensor list.
|
||||
wrapped_returned_args = wrap_all_outputs(result)
|
||||
# Extract results as DLPack tensor list.
|
||||
wrapped_returned_args = wrap_all_outputs(result)
|
||||
|
||||
ctx = wrapped_args[0]
|
||||
torch_interop_utils.unregister_grad_fn(id(ctx))
|
||||
ctx = wrapped_args[0]
|
||||
torch_interop_utils.unregister_grad_fn(id(ctx))
|
||||
|
||||
return tuple(wrapped_returned_args)
|
||||
except Exception as e:
|
||||
# Flush buffers. Otherwise, calling this from C++ may lose them.
|
||||
print('Exception happens when running ', backward_function)
|
||||
sys.stdout.flush()
|
||||
sys.stderr.flush()
|
||||
raise wrap_exception(ORTModuleFallbackException, e)
|
||||
return tuple(wrapped_returned_args)
|
||||
except Exception as e:
|
||||
# Flush buffers. Otherwise, calling this from C++ may lose them.
|
||||
print('Exception happens when running ', backward_function)
|
||||
sys.stdout.flush()
|
||||
sys.stderr.flush()
|
||||
raise wrap_exception(ORTModuleFallbackException, e)
|
||||
|
|
|
|||
|
|
@ -3,16 +3,22 @@
|
|||
#include <torch/extension.h>
|
||||
#include <torch/csrc/autograd/function.h>
|
||||
#include <torch/csrc/autograd/variable.h>
|
||||
#include <torch/csrc/autograd/functions/accumulate_grad.h>
|
||||
|
||||
// In Torch forward run (e.g. THPVariable_apply), ctx of type THPFunction* (which is also a PyObject*)
|
||||
// is created. The ctx is used to run user-defined forward function and backward function as the first
|
||||
// parameter. The same time, a cdata of type std::shared_ptr<PyNode> is created, cdata is owned by:
|
||||
// a). forward run output tensors as grad_fn_ property. (The full hierarchy is: Tensor own
|
||||
// In Torch forward run (e.g. THPFunction_apply), ctx of type THPFunction* (which is also a PyObject*)
|
||||
// is created (https://github.com/pytorch/pytorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/python_function.cpp#L673).
|
||||
// The ctx is used to run user-defined forward function and backward function as the first
|
||||
// parameter. The same time, a cdata of type std::shared_ptr<PyNode> is created
|
||||
// (https://github.com/pytorch/pytorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/python_function.cpp#L677),
|
||||
// cdata is owned by:
|
||||
// a). forward run output tensors as grad_fn_ property. (The full hierarchy is: Tensor owns
|
||||
// shared_pointer<TensorImpl>; TensorImpl owns std::unique_ptr<AutogradMeta>; AutogradMeta
|
||||
// manages grad_/grad_fn_/grad_accumulator_. Among them, grad_fn_ is std::shared_ptr<PyNode>,
|
||||
// the so called gradient function.)
|
||||
// b). the consumer operator of forward run outputs, will let its own PyNode/Node own the grad_fn_
|
||||
// (of type std::shared_ptr<PyNode>) of all inputs that require grad.
|
||||
// e.g, the so called gradient function.)
|
||||
// https://github.com/pytorch/pytorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/variable.h#L194
|
||||
// b). the consumer operator of forward run outputs, will let its own PyNode/Node (gradident function)
|
||||
// owns the grad_fn_ (of type std::shared_ptr<PyNode>) of all inputs that require grad.
|
||||
// https://github.com/pytorch/pytorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/function.h#L263
|
||||
// BUT, if we run torch computation within PythonOp, b) is lost. SO, for some cases, where forward outputs
|
||||
// are not used and freed before backward function runs, the grad_fn_ (std::shared_ptr<PyNode>) references
|
||||
// in a) will be released. Without b)'s reference, grad_fn_ release PyNode as reference count reach 0;
|
||||
|
|
@ -55,6 +61,45 @@ class PyNodeSharedPointerPool {
|
|||
};
|
||||
|
||||
|
||||
void clear_grad_fns_for_next_edges(at::Tensor target, std::vector<at::Tensor> saved_tensors) {
|
||||
// For leaf tensor, there will be a AccumulateGrad (gradident function) created, which owns a
|
||||
// reference to the tensor.
|
||||
// For any user saved tensors (with save_for_backward), if the tensor is leaf, we put the map
|
||||
// {AccumulateGrad*, Tensor*} into grad_fn_to_tensor_map.
|
||||
std::unordered_map<torch::autograd::Node*, at::Tensor*> grad_fn_to_tensor_map;
|
||||
for (auto& t: saved_tensors) {
|
||||
auto grad_fn = t.grad_fn();
|
||||
if (!grad_fn) {
|
||||
grad_fn = torch::autograd::impl::try_get_grad_accumulator(t);
|
||||
if (grad_fn) {
|
||||
TORCH_CHECK(grad_fn_to_tensor_map.find(grad_fn.get()) == grad_fn_to_tensor_map.end(),
|
||||
"found AccumulateGrad* is used by more than one tensors.");
|
||||
grad_fn_to_tensor_map.insert({grad_fn.get(), &t});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const auto& gradient_func_sptr = target.grad_fn();
|
||||
for (auto& edge : gradient_func_sptr->next_edges()) {
|
||||
torch::autograd::Node* node_func = edge.function.get();
|
||||
// If we find the next gradient function is AccumulateGrad, we will check whether its owned
|
||||
// tensors is in ctx.save_tensors or not. If yes, we skip it; otherwise, we clean the edge, which
|
||||
// will release the AccumulateGrad function.
|
||||
if (dynamic_cast<torch::autograd::AccumulateGrad*>(node_func)) {
|
||||
if (grad_fn_to_tensor_map.find(node_func) != grad_fn_to_tensor_map.end()) {
|
||||
// skip the edges that connect to saved_tensors. Because when unpack ctx.saved_tensors (using input, = ctx.saved_tensors) in backward,
|
||||
// there is such a check : if the saved tensor is a leaf and requires grad, it it should have grad accumulator.
|
||||
// If we clean the edge, then an exception "RuntimeError: No grad accumulator for a saved leaf!" will be thrown
|
||||
TORCH_WARN("Find a AccumulateGrad node, but skip it because the owned tensor is in saved_tensors.");
|
||||
continue;
|
||||
} else {
|
||||
TORCH_WARN("Find a AccumulateGrad node, and planned to clean the edge to it.");
|
||||
edge.function.reset();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void register_grad_fn(size_t ctx_address, at::Tensor target)
|
||||
{
|
||||
torch::autograd::AutogradMeta* autograd_meta = torch::autograd::impl::get_autograd_meta(target);
|
||||
|
|
@ -69,4 +114,5 @@ void unregister_grad_fn(size_t ctx_address)
|
|||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("register_grad_fn", ®ister_grad_fn, "increase grad_fn shared pointer reference.");
|
||||
m.def("unregister_grad_fn", &unregister_grad_fn, "release grad_fn shared pointer referece.");
|
||||
m.def("clear_grad_fns_for_next_edges", &clear_grad_fns_for_next_edges, "remove reference on next edges' gradident funtions.");
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue