custom autograd func memory (#8901)

* remove PythonOpGrad control dependency && avoid segmentation fault

* comment alignment

* fix bugs
This commit is contained in:
pengwa 2021-09-01 09:29:26 +08:00 committed by GitHub
parent feb747173e
commit 3eb08d4dc7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 181 additions and 7 deletions

View file

@ -268,6 +268,9 @@ if (onnxruntime_ENABLE_TRAINING)
file(GLOB onnxruntime_python_ortmodule_torch_cpp_ext_aten_op_executor_srcs CONFIGURE_DEPENDS
"${ORTTRAINING_SOURCE_DIR}/python/training/ortmodule/torch_cpp_extensions/aten_op_executor/*"
)
file(GLOB onnxruntime_python_ortmodule_torch_cpp_ext_torch_interop_utils_srcs CONFIGURE_DEPENDS
"${ORTTRAINING_SOURCE_DIR}/python/training/ortmodule/torch_cpp_extensions/torch_interop_utils/*"
)
file(GLOB onnxruntime_python_ortmodule_torch_cpp_ext_torch_gpu_allocator_srcs CONFIGURE_DEPENDS
"${ORTTRAINING_SOURCE_DIR}/python/training/ortmodule/torch_cpp_extensions/torch_gpu_allocator/*"
)
@ -504,6 +507,7 @@ if (onnxruntime_ENABLE_TRAINING)
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/experimental/json_config
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/torch_cpp_extensions
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/torch_cpp_extensions/aten_op_executor
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/torch_cpp_extensions/torch_interop_utils
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/torch_cpp_extensions/torch_gpu_allocator
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_capi_training_srcs}
@ -532,6 +536,9 @@ if (onnxruntime_ENABLE_TRAINING)
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_ortmodule_torch_cpp_ext_aten_op_executor_srcs}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/torch_cpp_extensions/aten_op_executor/
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_ortmodule_torch_cpp_ext_torch_interop_utils_srcs}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/torch_cpp_extensions/torch_interop_utils/
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_ortmodule_torch_cpp_ext_torch_gpu_allocator_srcs}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/torch_cpp_extensions/torch_gpu_allocator/

View file

@ -1785,10 +1785,9 @@ IMPLEMENT_GRADIENT_BUILDER(GetPythonOpGradient) {
}
// Also connect forward outputs to PythonOpGrad for random segmentation fault issues.
// Todo (pengwa): we should investigate whether we could avoid those outputs that are not used
// in backward computation.
// Todo (pengwa): remove the control dependency from PythonOpGrad schema.
for (int i = 1; i < GetSrcNodeOutputSize(); ++i) {
input_args.push_back(O(i));
input_args.push_back(ArgDef());
}
// src_attrs["input_requires_grads"] stores all inputs's requires_grad attributes,

View file

@ -10,6 +10,7 @@ from torch.utils.dlpack import from_dlpack, to_dlpack
from ._fallback import _FallbackManager, ORTModuleFallbackException, ORTModuleIOError, wrap_exception
from onnxruntime.training.ortmodule.torch_cpp_extensions import torch_interop_utils
def wrap_as_dlpack_or_not(grad_flag, tensor_flag, inplace_flag, training_mode_flag, arg):
'''
@ -76,30 +77,33 @@ def call_python_forward_function(
return arg
def wrap_all_outputs(result, training_mode_flag):
def extract_context(result):
# This is mainly to hold grad_fn references by registering it into our PyNodeSharedPointerPool.
def register_context(result):
# Search for context among all outputs.
ctx = None
first_tensor_output = None
for arg in result:
if not isinstance(arg, torch.Tensor) or not hasattr(arg, 'grad_fn'):
continue
# Use the first context we see because all of arg's
# share the same one.
ctx = arg.grad_fn
first_tensor_output = arg
break
if training_mode_flag:
# Must extract one valid context from result tensors.
assert ctx is not None
torch_interop_utils.register_grad_fn(id(ctx), first_tensor_output)
else:
# Context must not be present under non-training mode.
assert ctx is None
return ctx
if isinstance(result, torch.Tensor):
ctx = extract_context([result])
ctx = register_context([result])
return [ctx, to_dlpack(result)]
elif isinstance(result, tuple) or isinstance(result, list):
ctx = extract_context(result)
ctx = register_context(result)
wrapped = [ctx]
wrapped.extend(list(to_dlpack(value) if value is not None else None for value in result))
# Inside the returned list, first element is context and the rest
@ -177,6 +181,9 @@ def call_python_backward_function(
# Extract results as DLPack tensor list.
wrapped_returned_args = wrap_all_outputs(result)
ctx = wrapped_args[0]
torch_interop_utils.unregister_grad_fn(id(ctx))
return tuple(wrapped_returned_args)
except Exception as e:
# Flush buffers. Otherwise, calling this from C++ may lose them.

View file

@ -52,6 +52,16 @@ def build_torch_cpp_extensions():
print('There was an error compiling "aten_op_executor" PyTorch CPP extension')
sys.exit(ret_code)
setup_script = os.path.join(cpp_ext_dir,
'torch_interop_utils',
'setup.py')
ret_code = subprocess.call(f"{sys.executable} {setup_script} build",
cwd=cpp_ext_dir,
shell=True)
if ret_code != 0:
print('There was an error compiling "torch_interop_utils" PyTorch CPP extension')
sys.exit(ret_code)
############################################################################
# Copy Pytorch CPP Extensions to the local onnxruntime package folder
############################################################################

View file

@ -0,0 +1,15 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import os
from setuptools import setup, Extension
from torch.utils import cpp_extension
# Locate the C++ source that sits next to this setup script.
_script_dir = os.path.dirname(__file__)
filename = os.path.join(_script_dir, 'torch_interop_utils.cc')

# Describe the single PyTorch C++ extension produced by this script.
_interop_extension = cpp_extension.CppExtension(name='torch_interop_utils',
                                                sources=[filename])

# Hand the build over to PyTorch's BuildExtension command.
setup(name='torch_interop_utils',
      ext_modules=[_interop_extension],
      cmdclass={'build_ext': cpp_extension.BuildExtension})

View file

@ -0,0 +1,72 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <torch/extension.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/variable.h>
// In a Torch forward run (e.g. THPVariable_apply), a ctx of type THPFunction* (which is also a PyObject*)
// is created. The ctx is passed as the first parameter to the user-defined forward and backward
// functions. At the same time, a cdata of type std::shared_ptr<PyNode> is created; cdata is owned by:
// a). the forward run's output tensors, as their grad_fn_ property. (The full hierarchy is: Tensor owns
// shared_pointer<TensorImpl>; TensorImpl owns std::unique_ptr<AutogradMeta>; AutogradMeta
// manages grad_/grad_fn_/grad_accumulator_. Among them, grad_fn_ is std::shared_ptr<PyNode>,
// the so-called gradient function.)
// b). the consumer operators of the forward run's outputs, each of which lets its own PyNode/Node own
// the grad_fn_ (of type std::shared_ptr<PyNode>) of all inputs that require grad.
// BUT, if we run the torch computation within PythonOp, b) is lost. So, in cases where the forward outputs
// are not used and are freed before the backward function runs, the grad_fn_ (std::shared_ptr<PyNode>)
// references in a) are released. Without b)'s reference, grad_fn_ releases the PyNode once its reference
// count reaches 0; then, when PythonOpGrad runs, a segmentation fault occurs.
//
// So we add b)'s reference to this Pool when the forward run returns, and drop it from this Pool when the
// backward run completes; then ~PyNode() is called, which subsequently calls ~THPFunction(), destroying ctx.
// Singleton registry of autograd nodes (grad_fn), keyed by ctx address. An entry is added by
// RegisterGradFunc and stays owned by the pool until UnRegisterGradFunc removes it.
class PyNodeSharedPointerPool {
 public:
  // Returns the single pool instance (a function-local static, constructed on first call).
  static PyNodeSharedPointerPool& GetInstance() {
    static PyNodeSharedPointerPool pool;
    return pool;
  }

  // Stores the grad_fn_ of `autograd_meta` under `ctx_address`.
  // Fails a TORCH_CHECK when `autograd_meta` is null or when `ctx_address` is already registered.
  void RegisterGradFunc(const size_t& ctx_address, torch::autograd::AutogradMeta* autograd_meta) {
    // Guard the dereference below: a tensor without autograd metadata yields a null pointer.
    TORCH_CHECK(autograd_meta != nullptr, "fail to find autograd meta for ctx ", ctx_address);

    auto it = grad_fns_.find(ctx_address);
    TORCH_CHECK(it == grad_fns_.end(), "should not register grad_fn twice for ctx ", ctx_address);

    // Add new entry if key hasn't been registered.
    // std::move transfers the shared_ptr into the pool, leaving autograd_meta->grad_fn_ empty.
    // NOTE(review): the Python binding describes this as "increase ... reference"; confirm that
    // moving (rather than copying) the shared_ptr is the intended behavior.
    grad_fns_.emplace(ctx_address, std::move(autograd_meta->grad_fn_));
  }

  // Drops the pool's entry for `ctx_address`.
  // Fails a TORCH_CHECK when no entry was registered for that address.
  void UnRegisterGradFunc(const size_t& ctx_address) {
    auto it = grad_fns_.find(ctx_address);
    TORCH_CHECK(it != grad_fns_.end(), "fail to find grad_fn for ctx ", ctx_address);

    // Erase through the iterator we already hold instead of hashing the key a second time.
    grad_fns_.erase(it);
  }

 private:
  // Construction and destruction are private; instances are only reachable through GetInstance().
  PyNodeSharedPointerPool() {}
  ~PyNodeSharedPointerPool() {}

  // Neither copyable nor movable.
  PyNodeSharedPointerPool(const PyNodeSharedPointerPool&) = delete;
  PyNodeSharedPointerPool& operator=(const PyNodeSharedPointerPool&) = delete;
  PyNodeSharedPointerPool(PyNodeSharedPointerPool&&) = delete;
  PyNodeSharedPointerPool& operator=(PyNodeSharedPointerPool&&) = delete;

  // ctx address -> owning reference to the registered autograd node.
  std::unordered_map<size_t, std::shared_ptr<torch::autograd::Node>> grad_fns_;
};
// Registers the grad_fn of `target` in the shared pointer pool under `ctx_address`.
void register_grad_fn(size_t ctx_address, at::Tensor target) {
  auto& pool = PyNodeSharedPointerPool::GetInstance();
  // Look up the tensor's autograd metadata and hand it straight to the pool.
  pool.RegisterGradFunc(ctx_address, torch::autograd::impl::get_autograd_meta(target));
}
// Removes the pool entry previously registered under `ctx_address`.
void unregister_grad_fn(size_t ctx_address) {
  auto& pool = PyNodeSharedPointerPool::GetInstance();
  pool.UnRegisterGradFunc(ctx_address);
}
// Python bindings for the two pool entry points defined in this file.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // Adds a pool entry for a ctx address.
  m.def("register_grad_fn",
        &register_grad_fn,
        "increase grad_fn shared pointer reference.");

  // Removes the pool entry for a ctx address.
  m.def("unregister_grad_fn",
        &unregister_grad_fn,
        "release grad_fn shared pointer referece.");
}

View file

@ -74,6 +74,68 @@ def test_GeLU():
run_training_test_and_compare(model_builder, input_generator, label_input)
def test_GeLU_custom_func_rets_not_as_module_output():
    """Train a model whose custom autograd function output is not a module output.

    The output of the custom function is multiplied by 9 inside the module, so
    the custom function's own output tensor is only an intermediate value that
    can be released before the backward function runs.
    """

    @torch.jit.script
    def bias_gelu(bias, y):
        x = bias + y
        return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))

    @torch.jit.script
    def bias_gelu_backward(g, bias, y):
        x = bias + y
        tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
        ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 +
                                                     0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
        return ff*g

    class GeLUFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input, bias):
            # Both operands are needed again to compute the gradient.
            ctx.save_for_backward(input, bias)
            return bias_gelu(bias, input)

        @staticmethod
        def backward(ctx, grad_output):
            saved_input, saved_bias = ctx.saved_tensors
            grad = bias_gelu_backward(grad_output, saved_bias, saved_input)
            # The same gradient is returned for both `input` and `bias`.
            return grad, grad

    class GeLUModel(torch.nn.Module):
        def __init__(self, output_size):
            super(GeLUModel, self).__init__()
            self.gelu = GeLUFunction.apply
            bias_storage = torch.empty(
                output_size,
                device=torch.cuda.current_device(),
                dtype=torch.float)
            self.bias = Parameter(bias_storage)

            with torch.no_grad():
                self.bias.uniform_()

        def forward(self, model_input):
            activated = self.gelu(model_input, self.bias)
            # Multiply by 9 on purpose so that the custom function's output is
            # NOT a module output (it is consumed by a subsequent computation).
            # This aims to trigger a GC of the custom function's output: once
            # it is released, the underlying std::shared_ptr<PyNode> must still
            # have other references, otherwise a segmentation fault is triggered.
            return activated * 9

    output_size = 1024

    def model_builder():
        return GeLUModel(output_size)

    def input_generator():
        return torch.randn(output_size, dtype=torch.float)

    # Generate a label that has the same shape as the forward output.
    label_input = torch.ones([output_size])

    run_training_test_and_compare(model_builder, input_generator, label_input)
def test_MegatronF():
# MegatronGFunction is tested in distributed test files.
class MegatronFFunction(torch.autograd.Function):

View file

@ -323,8 +323,10 @@ if enable_training:
'onnxruntime.training.ortmodule.experimental.json_config',
'onnxruntime.training.ortmodule.torch_cpp_extensions',
'onnxruntime.training.ortmodule.torch_cpp_extensions.aten_op_executor',
'onnxruntime.training.ortmodule.torch_cpp_extensions.torch_interop_utils',
'onnxruntime.training.ortmodule.torch_cpp_extensions.torch_gpu_allocator'])
package_data['onnxruntime.training.ortmodule.torch_cpp_extensions.aten_op_executor'] = ['*.cc']
package_data['onnxruntime.training.ortmodule.torch_cpp_extensions.torch_interop_utils'] = ['*.cc']
package_data['onnxruntime.training.ortmodule.torch_cpp_extensions.torch_gpu_allocator'] = ['*.cc']
requirements_file = "requirements-training.txt"
# with training, we want to follow this naming convention: