From 3eb08d4dc76f0c81a16a474b7e1e16e7a5e86067 Mon Sep 17 00:00:00 2001 From: pengwa Date: Wed, 1 Sep 2021 09:29:26 +0800 Subject: [PATCH] custom autograd func memory (#8901) * remove PythonOpGrad control dependency && avoid segement fault * comment alignment * fix bugs --- cmake/onnxruntime_python.cmake | 7 ++ .../core/graph/gradient_builder.cc | 5 +- .../_custom_autograd_function_runner.py | 15 ++-- .../ortmodule/torch_cpp_extensions/install.py | 10 +++ .../torch_interop_utils/setup.py | 15 ++++ .../torch_interop_utils.cc | 72 +++++++++++++++++++ .../orttraining_test_ortmodule_autograd.py | 62 ++++++++++++++++ setup.py | 2 + 8 files changed, 181 insertions(+), 7 deletions(-) create mode 100644 orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/torch_interop_utils/setup.py create mode 100644 orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/torch_interop_utils/torch_interop_utils.cc diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index f25dc7e859..098502a36f 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -268,6 +268,9 @@ if (onnxruntime_ENABLE_TRAINING) file(GLOB onnxruntime_python_ortmodule_torch_cpp_ext_aten_op_executor_srcs CONFIGURE_DEPENDS "${ORTTRAINING_SOURCE_DIR}/python/training/ortmodule/torch_cpp_extensions/aten_op_executor/*" ) + file(GLOB onnxruntime_python_ortmodule_torch_cpp_ext_torch_interop_utils_srcs CONFIGURE_DEPENDS + "${ORTTRAINING_SOURCE_DIR}/python/training/ortmodule/torch_cpp_extensions/torch_interop_utils/*" + ) file(GLOB onnxruntime_python_ortmodule_torch_cpp_ext_torch_gpu_allocator_srcs CONFIGURE_DEPENDS "${ORTTRAINING_SOURCE_DIR}/python/training/ortmodule/torch_cpp_extensions/torch_gpu_allocator/*" ) @@ -504,6 +507,7 @@ if (onnxruntime_ENABLE_TRAINING) COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/ortmodule/experimental/json_config COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/ortmodule/torch_cpp_extensions COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/ortmodule/torch_cpp_extensions/aten_op_executor + COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/ortmodule/torch_cpp_extensions/torch_interop_utils COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/ortmodule/torch_cpp_extensions/torch_gpu_allocator COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_capi_training_srcs} @@ -532,6 +536,9 @@ if (onnxruntime_ENABLE_TRAINING) COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_ortmodule_torch_cpp_ext_aten_op_executor_srcs} $/onnxruntime/training/ortmodule/torch_cpp_extensions/aten_op_executor/ + COMMAND ${CMAKE_COMMAND} -E copy + ${onnxruntime_python_ortmodule_torch_cpp_ext_torch_interop_utils_srcs} + $/onnxruntime/training/ortmodule/torch_cpp_extensions/torch_interop_utils/ COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_ortmodule_torch_cpp_ext_torch_gpu_allocator_srcs} $/onnxruntime/training/ortmodule/torch_cpp_extensions/torch_gpu_allocator/ diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index 4af0dbc608..c79829bf6e 100755 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -1785,10 +1785,9 @@ IMPLEMENT_GRADIENT_BUILDER(GetPythonOpGradient) { } // Also connect forward outputs to PythonOpGrad for random segement fault issues. - // Todo (pengwa): we should investigate whether we could avoid those outputs that are not used - // in backward computation. + // Todo (pengwa): remove the control dependency from PythonOpGrad schema. for (int i = 1; i < GetSrcNodeOutputSize(); ++i) { - input_args.push_back(O(i)); + input_args.push_back(ArgDef()); } // src_attrs["input_requires_grads"] stores all inputs's requires_grad attributes, diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py index 92497441fb..f45837a482 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py @@ -10,6 +10,7 @@ from torch.utils.dlpack import from_dlpack, to_dlpack from ._fallback import _FallbackManager, ORTModuleFallbackException, ORTModuleIOError, wrap_exception +from onnxruntime.training.ortmodule.torch_cpp_extensions import torch_interop_utils def wrap_as_dlpack_or_not(grad_flag, tensor_flag, inplace_flag, training_mode_flag, arg): ''' @@ -76,30 +77,33 @@ def call_python_forward_function( return arg def wrap_all_outputs(result, training_mode_flag): - def extract_context(result): + # This is mainly to hold grad_fn references by registering it into our PyNodeSharedPointerPool. + def register_context(result): # Search for context among all outputs. ctx = None + first_tensor_output = None for arg in result: if not isinstance(arg, torch.Tensor) or not hasattr(arg, 'grad_fn'): continue # Use the first context we see because all of arg's # share the same one. ctx = arg.grad_fn + first_tensor_output = arg break if training_mode_flag: # Must extract one valid context from result tensors. assert ctx is not None + torch_interop_utils.register_grad_fn(id(ctx), first_tensor_output) else: # Context must not present under non-training mode. assert ctx is None - return ctx if isinstance(result, torch.Tensor): - ctx = extract_context([result]) + ctx = register_context([result]) return [ctx, to_dlpack(result)] elif isinstance(result, tuple) or isinstance(result, list): - ctx = extract_context(result) + ctx = register_context(result) wrapped = [ctx] wrapped.extend(list(to_dlpack(value) if value is not None else None for value in result)) # Inside the returned list, first element is context and the rest @@ -177,6 +181,9 @@ def call_python_backward_function( # Extract results as DLPack tensor list. wrapped_returned_args = wrap_all_outputs(result) + ctx = wrapped_args[0] + torch_interop_utils.unregister_grad_fn(id(ctx)) + return tuple(wrapped_returned_args) except Exception as e: # Flush buffers. Otherwise, calling this from C++ may lose them. diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/install.py b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/install.py index 7810ba5108..3818059d9e 100644 --- a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/install.py +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/install.py @@ -52,6 +52,16 @@ def build_torch_cpp_extensions(): print('There was an error compiling "aten_op_executor" PyTorch CPP extension') sys.exit(ret_code) + setup_script = os.path.join(cpp_ext_dir, + 'torch_interop_utils', + 'setup.py') + ret_code = subprocess.call(f"{sys.executable} {setup_script} build", + cwd=cpp_ext_dir, + shell=True) + if ret_code != 0: + print('There was an error compiling "torch_interop_utils" PyTorch CPP extension') + sys.exit(ret_code) + ############################################################################ # Copy Pytorch CPP Extensions to the local onnxruntime package folder ############################################################################ diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/torch_interop_utils/setup.py b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/torch_interop_utils/setup.py new file mode 100644 index 0000000000..cb9bff33c0 --- /dev/null +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/torch_interop_utils/setup.py @@ -0,0 +1,15 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +import os +from setuptools import setup, Extension +from torch.utils import cpp_extension + +filename = os.path.join(os.path.dirname(__file__), + 'torch_interop_utils.cc') +setup(name='torch_interop_utils', + ext_modules=[cpp_extension.CppExtension(name='torch_interop_utils', + sources=[filename])], + cmdclass={'build_ext': cpp_extension.BuildExtension}) diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/torch_interop_utils/torch_interop_utils.cc b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/torch_interop_utils/torch_interop_utils.cc new file mode 100644 index 0000000000..ffabdbe211 --- /dev/null +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/torch_interop_utils/torch_interop_utils.cc @@ -0,0 +1,72 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#include +#include +#include + +// In Torch forward run (e.g. THPVariable_apply), ctx of type THPFunction* (which is also a PyObject*) +// is created. The ctx is used to run user-defined forward function and backward function as the first +// parameter. The same time, a cdata of type std::shared_ptr is created, cdata is owned by: +// a). forward run output tensors as grad_fn_ property. (The full hierarchy is: Tensor own +// shared_pointer; TensorImpl owns std::unique_ptr; AutogradMeta +// manages grad_/grad_fn_/grad_accumulator_. Among them, grad_fn_ is std::shared_ptr, +// the so called gradient function.) +// b). the consumer operator of forward run outputs, will let its own PyNode/Node own the grad_fn_ +// (of type std::shared_ptr) of all inputs that require grad. +// BUT, if we run torch computation within PythonOp, b) is lost. SO, for some cases, where forward outputs +// are not used and freed before backward function runs, the grad_fn_ (std::shared_ptr) references +// in a) will be released. Without b)'s reference, grad_fn_ release PyNode as reference count reach 0; +// Then when PythonOpGrad runs, segment fault. +// +// So we add b)'s reference in this Pool when forward run returns; dereference from this Pool when backward +// completes, then ~PyNode() is called, which subsquently calls ~THPFunction() destorying ctx. +class PyNodeSharedPointerPool { + public: + static PyNodeSharedPointerPool& GetInstance() { + static PyNodeSharedPointerPool pool; + return pool; + }; + + void RegisterGradFunc(const size_t& ctx_address, torch::autograd::AutogradMeta* autograd_meta){ + auto it = grad_fns_.find(ctx_address); + TORCH_CHECK(it == grad_fns_.end(), "should not register grad_fn twice for ctx ", ctx_address); + + // Add new entry if key hasn't been registered. + grad_fns_.emplace(ctx_address, std::move(autograd_meta->grad_fn_)); + }; + + void UnRegisterGradFunc(const size_t& ctx_address){ + auto it = grad_fns_.find(ctx_address); + TORCH_CHECK(it != grad_fns_.end(), "fail to find grad_fn for ctx ", ctx_address); + + grad_fns_.erase(ctx_address); + }; + + private: + PyNodeSharedPointerPool(){}; + ~PyNodeSharedPointerPool(){}; + + PyNodeSharedPointerPool(const PyNodeSharedPointerPool&) = delete; + PyNodeSharedPointerPool& operator=(const PyNodeSharedPointerPool&) = delete; + PyNodeSharedPointerPool(PyNodeSharedPointerPool&&) = delete; + PyNodeSharedPointerPool& operator=(PyNodeSharedPointerPool&&) = delete; + + std::unordered_map> grad_fns_; +}; + + +void register_grad_fn(size_t ctx_address, at::Tensor target) +{ + torch::autograd::AutogradMeta* autograd_meta = torch::autograd::impl::get_autograd_meta(target); + PyNodeSharedPointerPool::GetInstance().RegisterGradFunc(ctx_address, autograd_meta); +} + +void unregister_grad_fn(size_t ctx_address) +{ + PyNodeSharedPointerPool::GetInstance().UnRegisterGradFunc(ctx_address); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("register_grad_fn", ®ister_grad_fn, "increase grad_fn shared pointer reference."); + m.def("unregister_grad_fn", &unregister_grad_fn, "release grad_fn shared pointer referece."); +} diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py index 8b5ed034bb..a71fd330f7 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py @@ -74,6 +74,68 @@ def test_GeLU(): run_training_test_and_compare(model_builder, input_generator, label_input) +def test_GeLU_custom_func_rets_not_as_module_output(): + @torch.jit.script + def bias_gelu(bias, y): + x = bias + y + return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))) + + @torch.jit.script + def bias_gelu_backward(g, bias, y): + x = bias + y + tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) + ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out) + return ff*g + + class GeLUFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, input, bias): + ctx.save_for_backward(input, bias) + return bias_gelu(bias, input) + + @staticmethod + def backward(ctx, grad_output): + input, bias = ctx.saved_tensors + tmp = bias_gelu_backward(grad_output, bias, input) + return tmp, tmp + + class GeLUModel(torch.nn.Module): + def __init__(self, output_size): + super(GeLUModel, self).__init__() + self.relu = GeLUFunction.apply + self.bias = Parameter(torch.empty( + output_size, + device=torch.cuda.current_device(), + dtype=torch.float)) + + with torch.no_grad(): + self.bias.uniform_() + + def forward(self, model_input): + out = self.relu(model_input, self.bias) + # add * 9 by intention to make custom function's output + # NOT as module outputs (which are consumed by subsquent computations). + # This aims to trigger a GC for "out", saying, out is released, + # the underlying std::shared still have other references. + # Otherwise, a segementfault will be triggered. + out = out * 9 + return out + + output_size = 1024 + + def model_builder(): + return GeLUModel(output_size) + + def input_generator(): + return torch.randn(output_size, dtype=torch.float) + + # generate a label that have same shape as forward output. + label_input = torch.ones([output_size]) + + run_training_test_and_compare(model_builder, input_generator, label_input) + + def test_MegatronF(): # MegatronGFunction is tested in distributed test files. class MegatronFFunction(torch.autograd.Function): diff --git a/setup.py b/setup.py index 1981632f48..54fa3ab310 100644 --- a/setup.py +++ b/setup.py @@ -323,8 +323,10 @@ if enable_training: 'onnxruntime.training.ortmodule.experimental.json_config', 'onnxruntime.training.ortmodule.torch_cpp_extensions', 'onnxruntime.training.ortmodule.torch_cpp_extensions.aten_op_executor', + 'onnxruntime.training.ortmodule.torch_cpp_extensions.torch_interop_utils', 'onnxruntime.training.ortmodule.torch_cpp_extensions.torch_gpu_allocator']) package_data['onnxruntime.training.ortmodule.torch_cpp_extensions.aten_op_executor'] = ['*.cc'] + package_data['onnxruntime.training.ortmodule.torch_cpp_extensions.torch_interop_utils'] = ['*.cc'] package_data['onnxruntime.training.ortmodule.torch_cpp_extensions.torch_gpu_allocator'] = ['*.cc'] requirements_file = "requirements-training.txt" # with training, we want to follow this naming convention: