diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py index f45837a482..208fd16e0e 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py @@ -81,6 +81,9 @@ def call_python_forward_function( def register_context(result): # Search for context among all outputs. ctx = None + # All forward outputs of torch.autograd.Function shared a same gradient function pointer, + # so here we just get the first tensor having grad_fn attribute. + # (https://github.com/pytorch/pytorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/custom_function.cpp#L267) first_tensor_output = None for arg in result: if not isinstance(arg, torch.Tensor) or not hasattr(arg, 'grad_fn'): @@ -93,6 +96,22 @@ def call_python_forward_function( if training_mode_flag: # Must extract one valid context from result tensors. assert ctx is not None + + # FORWARD BACKWARD FUNCTION CONNECTIONS + # input_1 (leaf, constructed by from_dlpack) <----reference---- AccumulateGrad gradient function + # ↓ ↑ + # autograd.Function apply() ------------> autograd.Function backward() + # ↓ | ↑ + # output_1, output_2 --- shared_ptr --- ↑ + # ↓ previous gradient function + + # We remove the edges starting between current autograd.Function's gradient function and + # it's input's gradient function (e.g. AccumulateGrad gradient function), then + # AccumulateGrad gradient function will be destroyed, releasing the reference to input_1 + # (https://github.com/pytorch/pytorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/functions/accumulate_grad.cpp#L21). + # The next edges are stored in Node, with which we can get next gradient function. + # https://github.com/pytorch/pytorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/function.h#L527 + torch_interop_utils.clear_grad_fns_for_next_edges(first_tensor_output, ctx.saved_tensors) torch_interop_utils.register_grad_fn(id(ctx), first_tensor_output) else: # Context must not present under non-training mode. @@ -158,36 +177,37 @@ def call_python_backward_function( inplace: indicates if args can be modified inside the custom function. args: inputs to "backward_function". ''' - def wrap_all_outputs(result): - if isinstance(result, torch.Tensor): - return [to_dlpack(result)] - elif isinstance(result, tuple) or isinstance(result, list): - return [to_dlpack(value) if value is not None else None for value in result] - else: - raise wrap_exception(ORTModuleIOError, - TypeError(f'ORTModule does not support the following model output type {type(result)}.')) + with torch.no_grad(): + def wrap_all_outputs(result): + if isinstance(result, torch.Tensor): + return [to_dlpack(result)] + elif isinstance(result, tuple) or isinstance(result, list): + return [to_dlpack(value) if value is not None else None for value in result] + else: + raise wrap_exception(ORTModuleIOError, + TypeError(f'ORTModule does not support the following model output type {type(result)}.')) - try: - # Backward inputs should not require gradients. - assert all(grad_flag == 0 for grad_flag in requires_grad_flags) + try: + # Backward inputs should not require gradients. + assert all(grad_flag == 0 for grad_flag in requires_grad_flags) - # Prepare inputs for calling Python function. - wrapped_args = list(wrap_as_dlpack_or_not(grad_flag, tensor_flag, inplace, is_training_mode, arg) - for grad_flag, tensor_flag, arg in zip(requires_grad_flags, tensor_type_flags, args)) + # Prepare inputs for calling Python function. + wrapped_args = list(wrap_as_dlpack_or_not(grad_flag, tensor_flag, inplace, is_training_mode, arg) + for grad_flag, tensor_flag, arg in zip(requires_grad_flags, tensor_type_flags, args)) - # Call Python function. - result = backward_function(*wrapped_args) + # Call Python function. + result = backward_function(*wrapped_args) - # Extract results as DLPack tensor list. - wrapped_returned_args = wrap_all_outputs(result) + # Extract results as DLPack tensor list. + wrapped_returned_args = wrap_all_outputs(result) - ctx = wrapped_args[0] - torch_interop_utils.unregister_grad_fn(id(ctx)) + ctx = wrapped_args[0] + torch_interop_utils.unregister_grad_fn(id(ctx)) - return tuple(wrapped_returned_args) - except Exception as e: - # Flush buffers. Otherwise, calling this from C++ may lose them. - print('Exception happens when running ', backward_function) - sys.stdout.flush() - sys.stderr.flush() - raise wrap_exception(ORTModuleFallbackException, e) + return tuple(wrapped_returned_args) + except Exception as e: + # Flush buffers. Otherwise, calling this from C++ may lose them. + print('Exception happens when running ', backward_function) + sys.stdout.flush() + sys.stderr.flush() + raise wrap_exception(ORTModuleFallbackException, e) diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/torch_interop_utils/torch_interop_utils.cc b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/torch_interop_utils/torch_interop_utils.cc index ffabdbe211..efc84e11c1 100644 --- a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/torch_interop_utils/torch_interop_utils.cc +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/torch_interop_utils/torch_interop_utils.cc @@ -3,16 +3,22 @@ #include #include #include +#include -// In Torch forward run (e.g. THPVariable_apply), ctx of type THPFunction* (which is also a PyObject*) -// is created. The ctx is used to run user-defined forward function and backward function as the first -// parameter. The same time, a cdata of type std::shared_ptr is created, cdata is owned by: -// a). forward run output tensors as grad_fn_ property. (The full hierarchy is: Tensor own +// In Torch forward run (e.g. THPFunction_apply), ctx of type THPFunction* (which is also a PyObject*) +// is created (https://github.com/pytorch/pytorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/python_function.cpp#L673). +// The ctx is used to run user-defined forward function and backward function as the first +// parameter. The same time, a cdata of type std::shared_ptr is created +// (https://github.com/pytorch/pytorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/python_function.cpp#L677), +// cdata is owned by: +// a). forward run output tensors as grad_fn_ property. (The full hierarchy is: Tensor owns // shared_pointer; TensorImpl owns std::unique_ptr; AutogradMeta // manages grad_/grad_fn_/grad_accumulator_. Among them, grad_fn_ is std::shared_ptr, -// the so called gradient function.) -// b). the consumer operator of forward run outputs, will let its own PyNode/Node own the grad_fn_ -// (of type std::shared_ptr) of all inputs that require grad. +// e.g, the so called gradient function.) +// https://github.com/pytorch/pytorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/variable.h#L194 +// b). the consumer operator of forward run outputs, will let its own PyNode/Node (gradident function) +// owns the grad_fn_ (of type std::shared_ptr) of all inputs that require grad. +// https://github.com/pytorch/pytorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/function.h#L263 // BUT, if we run torch computation within PythonOp, b) is lost. SO, for some cases, where forward outputs // are not used and freed before backward function runs, the grad_fn_ (std::shared_ptr) references // in a) will be released. Without b)'s reference, grad_fn_ release PyNode as reference count reach 0; @@ -55,6 +61,45 @@ class PyNodeSharedPointerPool { }; +void clear_grad_fns_for_next_edges(at::Tensor target, std::vector saved_tensors) { + // For leaf tensor, there will be a AccumulateGrad (gradident function) created, which owns a + // reference to the tensor. + // For any user saved tensors (with save_for_backward), if the tensor is leaf, we put the map + // {AccumulateGrad*, Tensor*} into grad_fn_to_tensor_map. + std::unordered_map grad_fn_to_tensor_map; + for (auto& t: saved_tensors) { + auto grad_fn = t.grad_fn(); + if (!grad_fn) { + grad_fn = torch::autograd::impl::try_get_grad_accumulator(t); + if (grad_fn) { + TORCH_CHECK(grad_fn_to_tensor_map.find(grad_fn.get()) == grad_fn_to_tensor_map.end(), + "found AccumulateGrad* is used by more than one tensors."); + grad_fn_to_tensor_map.insert({grad_fn.get(), &t}); + } + } + } + + const auto& gradient_func_sptr = target.grad_fn(); + for (auto& edge : gradient_func_sptr->next_edges()) { + torch::autograd::Node* node_func = edge.function.get(); + // If we find the next gradient function is AccumulateGrad, we will check whether its owned + // tensors is in ctx.save_tensors or not. If yes, we skip it; otherwise, we clean the edge, which + // will release the AccumulateGrad function. + if (dynamic_cast(node_func)) { + if (grad_fn_to_tensor_map.find(node_func) != grad_fn_to_tensor_map.end()) { + // skip the edges that connect to saved_tensors. Because when unpack ctx.saved_tensors (using input, = ctx.saved_tensors) in backward, + // there is such a check : if the saved tensor is a leaf and requires grad, it it should have grad accumulator. + // If we clean the edge, then an exception "RuntimeError: No grad accumulator for a saved leaf!" will be thrown + TORCH_WARN("Find a AccumulateGrad node, but skip it because the owned tensor is in saved_tensors."); + continue; + } else { + TORCH_WARN("Find a AccumulateGrad node, and planned to clean the edge to it."); + edge.function.reset(); + } + } + } +} + void register_grad_fn(size_t ctx_address, at::Tensor target) { torch::autograd::AutogradMeta* autograd_meta = torch::autograd::impl::get_autograd_meta(target); @@ -69,4 +114,5 @@ void unregister_grad_fn(size_t ctx_address) PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("register_grad_fn", ®ister_grad_fn, "increase grad_fn shared pointer reference."); m.def("unregister_grad_fn", &unregister_grad_fn, "release grad_fn shared pointer referece."); + m.def("clear_grad_fns_for_next_edges", &clear_grad_fns_for_next_edges, "remove reference on next edges' gradident funtions."); }