custom autograd func memory refinement (#8993)

* Release torch tensor referenced by torch gradient graph (created in PythonOp)

* Update orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/torch_interop_utils/torch_interop_utils.cc

* refine with comments

Co-authored-by: Wei-Sheng Chin <wschin@outlook.com>
This commit is contained in:
pengwa 2021-09-09 18:37:24 +08:00 committed by GitHub
parent d39959172f
commit d209fe29b9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 100 additions and 34 deletions

View file

@ -81,6 +81,9 @@ def call_python_forward_function(
def register_context(result):
# Search for context among all outputs.
ctx = None
# All forward outputs of torch.autograd.Function share the same gradient function pointer,
# so here we just take the first tensor that has a grad_fn attribute.
# (https://github.com/pytorch/pytorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/custom_function.cpp#L267)
first_tensor_output = None
for arg in result:
if not isinstance(arg, torch.Tensor) or not hasattr(arg, 'grad_fn'):
@ -93,6 +96,22 @@ def call_python_forward_function(
if training_mode_flag:
# Must extract one valid context from result tensors.
assert ctx is not None
# FORWARD BACKWARD FUNCTION CONNECTIONS
# input_1 (leaf, constructed by from_dlpack) <----reference---- AccumulateGrad gradient function
# ↓ ↑
# autograd.Function apply() ------------> autograd.Function backward()
# ↓ | ↑
# output_1, output_2 --- shared_ptr<PyNode> --- ↑
# ↓ previous gradient function
# We remove the edges between the current autograd.Function's gradient function and
# its inputs' gradient functions (e.g. AccumulateGrad gradient function); then the
# AccumulateGrad gradient function will be destroyed, releasing the reference to input_1
# (https://github.com/pytorch/pytorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/functions/accumulate_grad.cpp#L21).
# The next edges are stored in Node, with which we can get next gradient function.
# https://github.com/pytorch/pytorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/function.h#L527
torch_interop_utils.clear_grad_fns_for_next_edges(first_tensor_output, ctx.saved_tensors)
torch_interop_utils.register_grad_fn(id(ctx), first_tensor_output)
else:
# Context must not present under non-training mode.
@ -158,36 +177,37 @@ def call_python_backward_function(
inplace: indicates if args can be modified inside the custom function.
args: inputs to "backward_function".
'''
def wrap_all_outputs(result):
    """Wrap the backward function's result as a list of to_dlpack() outputs.

    A single torch.Tensor yields a one-element list. A tuple or list yields one
    entry per element, with None entries passed through unchanged. Any other
    result type raises an ORTModuleIOError-wrapped TypeError.
    """
    if isinstance(result, torch.Tensor):
        return [to_dlpack(result)]

    if not isinstance(result, (tuple, list)):
        raise wrap_exception(ORTModuleIOError,
                             TypeError(f'ORTModule does not support the following model output type {type(result)}.'))

    wrapped = []
    for item in result:
        # None entries are kept as None instead of being converted.
        wrapped.append(None if item is None else to_dlpack(item))
    return wrapped
with torch.no_grad():
def wrap_all_outputs(result):
    """Wrap the backward function's result as a list of to_dlpack() outputs.

    A single torch.Tensor yields a one-element list. A tuple or list yields one
    entry per element, with None entries passed through unchanged. Any other
    result type raises an ORTModuleIOError-wrapped TypeError.
    """
    if isinstance(result, torch.Tensor):
        return [to_dlpack(result)]

    if isinstance(result, (tuple, list)):
        wrapped = []
        for item in result:
            # None entries are kept as None instead of being converted.
            wrapped.append(to_dlpack(item) if item is not None else None)
        return wrapped

    raise wrap_exception(ORTModuleIOError,
                         TypeError(f'ORTModule does not support the following model output type {type(result)}.'))
try:
# Backward inputs should not require gradients.
assert all(grad_flag == 0 for grad_flag in requires_grad_flags)
try:
# Backward inputs should not require gradients.
assert all(grad_flag == 0 for grad_flag in requires_grad_flags)
# Prepare inputs for calling Python function.
wrapped_args = list(wrap_as_dlpack_or_not(grad_flag, tensor_flag, inplace, is_training_mode, arg)
for grad_flag, tensor_flag, arg in zip(requires_grad_flags, tensor_type_flags, args))
# Prepare inputs for calling Python function.
wrapped_args = list(wrap_as_dlpack_or_not(grad_flag, tensor_flag, inplace, is_training_mode, arg)
for grad_flag, tensor_flag, arg in zip(requires_grad_flags, tensor_type_flags, args))
# Call Python function.
result = backward_function(*wrapped_args)
# Call Python function.
result = backward_function(*wrapped_args)
# Extract results as DLPack tensor list.
wrapped_returned_args = wrap_all_outputs(result)
# Extract results as DLPack tensor list.
wrapped_returned_args = wrap_all_outputs(result)
ctx = wrapped_args[0]
torch_interop_utils.unregister_grad_fn(id(ctx))
ctx = wrapped_args[0]
torch_interop_utils.unregister_grad_fn(id(ctx))
return tuple(wrapped_returned_args)
except Exception as e:
# Flush buffers. Otherwise, calling this from C++ may lose them.
print('Exception happens when running ', backward_function)
sys.stdout.flush()
sys.stderr.flush()
raise wrap_exception(ORTModuleFallbackException, e)
return tuple(wrapped_returned_args)
except Exception as e:
# Flush buffers. Otherwise, calling this from C++ may lose them.
print('Exception happens when running ', backward_function)
sys.stdout.flush()
sys.stderr.flush()
raise wrap_exception(ORTModuleFallbackException, e)

View file

@ -3,16 +3,22 @@
#include <torch/extension.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/autograd/functions/accumulate_grad.h>
// In Torch forward run (e.g. THPVariable_apply), ctx of type THPFunction* (which is also a PyObject*)
// is created. The ctx is used to run user-defined forward function and backward function as the first
// parameter. The same time, a cdata of type std::shared_ptr<PyNode> is created, cdata is owned by:
// a). forward run output tensors as grad_fn_ property. (The full hierarchy is: Tensor own
// In Torch forward run (e.g. THPFunction_apply), ctx of type THPFunction* (which is also a PyObject*)
// is created (https://github.com/pytorch/pytorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/python_function.cpp#L673).
// The ctx is used to run user-defined forward function and backward function as the first
// parameter. At the same time, a cdata of type std::shared_ptr<PyNode> is created
// (https://github.com/pytorch/pytorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/python_function.cpp#L677),
// cdata is owned by:
// a). forward run output tensors as grad_fn_ property. (The full hierarchy is: Tensor owns
// shared_pointer<TensorImpl>; TensorImpl owns std::unique_ptr<AutogradMeta>; AutogradMeta
// manages grad_/grad_fn_/grad_accumulator_. Among them, grad_fn_ is std::shared_ptr<PyNode>,
// the so called gradient function.)
// b). the consumer operator of forward run outputs, will let its own PyNode/Node own the grad_fn_
// (of type std::shared_ptr<PyNode>) of all inputs that require grad.
// e.g, the so called gradient function.)
// https://github.com/pytorch/pytorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/variable.h#L194
// b). the consumer operator of forward run outputs, will let its own PyNode/Node (gradient function)
// own the grad_fn_ (of type std::shared_ptr<PyNode>) of all inputs that require grad.
// https://github.com/pytorch/pytorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/function.h#L263
// BUT, if we run torch computation within PythonOp, b) is lost. SO, for some cases, where forward outputs
// are not used and freed before backward function runs, the grad_fn_ (std::shared_ptr<PyNode>) references
// in a) will be released. Without b)'s reference, grad_fn_ releases PyNode as its reference count reaches 0;
@ -55,6 +61,45 @@ class PyNodeSharedPointerPool {
};
// Clears the edges from `target`'s gradient function to AccumulateGrad nodes, so that those
// AccumulateGrad nodes (and the references they hold to their leaf tensors) can be released.
// Edges to AccumulateGrad nodes that belong to tensors in `saved_tensors` are kept.
//
// target:        a forward output tensor; its grad_fn() provides the next edges to inspect.
// saved_tensors: the tensors the user saved on the context (ctx.saved_tensors).
//
// Raises (via TORCH_CHECK) if `target` has no gradient function, or if two saved tensors
// resolve to the same AccumulateGrad node.
void clear_grad_fns_for_next_edges(at::Tensor target, std::vector<at::Tensor> saved_tensors) {
  // For a leaf tensor, there will be an AccumulateGrad (gradient function) created, which owns a
  // reference to the tensor.
  // For any user saved tensor (with save_for_backward) that has no grad_fn but does have a grad
  // accumulator, we record {AccumulateGrad*, Tensor*} in grad_fn_to_tensor_map.
  std::unordered_map<torch::autograd::Node*, at::Tensor*> grad_fn_to_tensor_map;
  for (auto& t : saved_tensors) {
    auto grad_fn = t.grad_fn();
    if (grad_fn) {
      // Not a leaf: it has a real gradient function, nothing to record.
      continue;
    }

    grad_fn = torch::autograd::impl::try_get_grad_accumulator(t);
    if (!grad_fn) {
      continue;
    }

    TORCH_CHECK(grad_fn_to_tensor_map.find(grad_fn.get()) == grad_fn_to_tensor_map.end(),
                "found AccumulateGrad* is used by more than one tensors.");
    grad_fn_to_tensor_map.insert({grad_fn.get(), &t});
  }

  const auto& gradient_func_sptr = target.grad_fn();
  // Fail with a catchable error instead of dereferencing a null gradient function below.
  TORCH_CHECK(gradient_func_sptr != nullptr,
              "clear_grad_fns_for_next_edges: target tensor does not have a gradient function (grad_fn).");

  for (auto& edge : gradient_func_sptr->next_edges()) {
    torch::autograd::Node* node_func = edge.function.get();

    // Only edges that point to an AccumulateGrad node are candidates for cleaning.
    if (!dynamic_cast<torch::autograd::AccumulateGrad*>(node_func)) {
      continue;
    }

    if (grad_fn_to_tensor_map.find(node_func) != grad_fn_to_tensor_map.end()) {
      // Skip the edges that connect to saved_tensors. When ctx.saved_tensors is unpacked in
      // backward (e.g. `input, = ctx.saved_tensors`), there is such a check: if the saved tensor
      // is a leaf and requires grad, it should have a grad accumulator. If we clean the edge, an
      // exception "RuntimeError: No grad accumulator for a saved leaf!" will be thrown.
      TORCH_WARN("Find a AccumulateGrad node, but skip it because the owned tensor is in saved_tensors.");
      continue;
    }

    // The owned tensor is not in saved_tensors: clean the edge, which releases this reference
    // to the AccumulateGrad function.
    TORCH_WARN("Find a AccumulateGrad node, and planned to clean the edge to it.");
    edge.function.reset();
  }
}
void register_grad_fn(size_t ctx_address, at::Tensor target)
{
torch::autograd::AutogradMeta* autograd_meta = torch::autograd::impl::get_autograd_meta(target);
@ -69,4 +114,5 @@ void unregister_grad_fn(size_t ctx_address)
// Python bindings for the gradient-function helpers defined in this file.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // Registration / release of a grad_fn shared pointer reference.
  m.def("register_grad_fn",
        &register_grad_fn,
        "increase grad_fn shared pointer reference.");
  m.def("unregister_grad_fn",
        &unregister_grad_fn,
        "release grad_fn shared pointer referece.");

  // Removal of references held through the next edges of a gradient function.
  m.def("clear_grad_fns_for_next_edges",
        &clear_grad_fns_for_next_edges,
        "remove reference on next edges' gradident funtions.");
}