mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-02 03:55:34 +00:00
custom autograd func memory (#8901)
* remove PythonOpGrad control dependency && avoid segement fault * comment alignment * fix bugs
This commit is contained in:
parent
feb747173e
commit
3eb08d4dc7
8 changed files with 181 additions and 7 deletions
|
|
@ -268,6 +268,9 @@ if (onnxruntime_ENABLE_TRAINING)
|
|||
file(GLOB onnxruntime_python_ortmodule_torch_cpp_ext_aten_op_executor_srcs CONFIGURE_DEPENDS
|
||||
"${ORTTRAINING_SOURCE_DIR}/python/training/ortmodule/torch_cpp_extensions/aten_op_executor/*"
|
||||
)
|
||||
file(GLOB onnxruntime_python_ortmodule_torch_cpp_ext_torch_interop_utils_srcs CONFIGURE_DEPENDS
|
||||
"${ORTTRAINING_SOURCE_DIR}/python/training/ortmodule/torch_cpp_extensions/torch_interop_utils/*"
|
||||
)
|
||||
file(GLOB onnxruntime_python_ortmodule_torch_cpp_ext_torch_gpu_allocator_srcs CONFIGURE_DEPENDS
|
||||
"${ORTTRAINING_SOURCE_DIR}/python/training/ortmodule/torch_cpp_extensions/torch_gpu_allocator/*"
|
||||
)
|
||||
|
|
@ -504,6 +507,7 @@ if (onnxruntime_ENABLE_TRAINING)
|
|||
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/experimental/json_config
|
||||
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/torch_cpp_extensions
|
||||
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/torch_cpp_extensions/aten_op_executor
|
||||
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/torch_cpp_extensions/torch_interop_utils
|
||||
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/torch_cpp_extensions/torch_gpu_allocator
|
||||
COMMAND ${CMAKE_COMMAND} -E copy
|
||||
${onnxruntime_python_capi_training_srcs}
|
||||
|
|
@ -532,6 +536,9 @@ if (onnxruntime_ENABLE_TRAINING)
|
|||
COMMAND ${CMAKE_COMMAND} -E copy
|
||||
${onnxruntime_python_ortmodule_torch_cpp_ext_aten_op_executor_srcs}
|
||||
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/torch_cpp_extensions/aten_op_executor/
|
||||
COMMAND ${CMAKE_COMMAND} -E copy
|
||||
${onnxruntime_python_ortmodule_torch_cpp_ext_torch_interop_utils_srcs}
|
||||
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/torch_cpp_extensions/torch_interop_utils/
|
||||
COMMAND ${CMAKE_COMMAND} -E copy
|
||||
${onnxruntime_python_ortmodule_torch_cpp_ext_torch_gpu_allocator_srcs}
|
||||
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/torch_cpp_extensions/torch_gpu_allocator/
|
||||
|
|
|
|||
|
|
@ -1785,10 +1785,9 @@ IMPLEMENT_GRADIENT_BUILDER(GetPythonOpGradient) {
|
|||
}
|
||||
|
||||
// Also connect forward outputs to PythonOpGrad for random segement fault issues.
|
||||
// Todo (pengwa): we should investigate whether we could avoid those outputs that are not used
|
||||
// in backward computation.
|
||||
// Todo (pengwa): remove the control dependency from PythonOpGrad schema.
|
||||
for (int i = 1; i < GetSrcNodeOutputSize(); ++i) {
|
||||
input_args.push_back(O(i));
|
||||
input_args.push_back(ArgDef());
|
||||
}
|
||||
|
||||
// src_attrs["input_requires_grads"] stores all inputs's requires_grad attributes,
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from torch.utils.dlpack import from_dlpack, to_dlpack
|
|||
|
||||
from ._fallback import _FallbackManager, ORTModuleFallbackException, ORTModuleIOError, wrap_exception
|
||||
|
||||
from onnxruntime.training.ortmodule.torch_cpp_extensions import torch_interop_utils
|
||||
|
||||
def wrap_as_dlpack_or_not(grad_flag, tensor_flag, inplace_flag, training_mode_flag, arg):
|
||||
'''
|
||||
|
|
@ -76,30 +77,33 @@ def call_python_forward_function(
|
|||
return arg
|
||||
|
||||
def wrap_all_outputs(result, training_mode_flag):
|
||||
def extract_context(result):
|
||||
# This is mainly to hold grad_fn references by registering it into our PyNodeSharedPointerPool.
|
||||
def register_context(result):
|
||||
# Search for context among all outputs.
|
||||
ctx = None
|
||||
first_tensor_output = None
|
||||
for arg in result:
|
||||
if not isinstance(arg, torch.Tensor) or not hasattr(arg, 'grad_fn'):
|
||||
continue
|
||||
# Use the first context we see because all of arg's
|
||||
# share the same one.
|
||||
ctx = arg.grad_fn
|
||||
first_tensor_output = arg
|
||||
break
|
||||
if training_mode_flag:
|
||||
# Must extract one valid context from result tensors.
|
||||
assert ctx is not None
|
||||
torch_interop_utils.register_grad_fn(id(ctx), first_tensor_output)
|
||||
else:
|
||||
# Context must not present under non-training mode.
|
||||
assert ctx is None
|
||||
|
||||
return ctx
|
||||
|
||||
if isinstance(result, torch.Tensor):
|
||||
ctx = extract_context([result])
|
||||
ctx = register_context([result])
|
||||
return [ctx, to_dlpack(result)]
|
||||
elif isinstance(result, tuple) or isinstance(result, list):
|
||||
ctx = extract_context(result)
|
||||
ctx = register_context(result)
|
||||
wrapped = [ctx]
|
||||
wrapped.extend(list(to_dlpack(value) if value is not None else None for value in result))
|
||||
# Inside the returned list, first element is context and the rest
|
||||
|
|
@ -177,6 +181,9 @@ def call_python_backward_function(
|
|||
# Extract results as DLPack tensor list.
|
||||
wrapped_returned_args = wrap_all_outputs(result)
|
||||
|
||||
ctx = wrapped_args[0]
|
||||
torch_interop_utils.unregister_grad_fn(id(ctx))
|
||||
|
||||
return tuple(wrapped_returned_args)
|
||||
except Exception as e:
|
||||
# Flush buffers. Otherwise, calling this from C++ may lose them.
|
||||
|
|
|
|||
|
|
@ -52,6 +52,16 @@ def build_torch_cpp_extensions():
|
|||
print('There was an error compiling "aten_op_executor" PyTorch CPP extension')
|
||||
sys.exit(ret_code)
|
||||
|
||||
setup_script = os.path.join(cpp_ext_dir,
|
||||
'torch_interop_utils',
|
||||
'setup.py')
|
||||
ret_code = subprocess.call(f"{sys.executable} {setup_script} build",
|
||||
cwd=cpp_ext_dir,
|
||||
shell=True)
|
||||
if ret_code != 0:
|
||||
print('There was an error compiling "torch_interop_utils" PyTorch CPP extension')
|
||||
sys.exit(ret_code)
|
||||
|
||||
############################################################################
|
||||
# Copy Pytorch CPP Extensions to the local onnxruntime package folder
|
||||
############################################################################
|
||||
|
|
|
|||
|
|
@ -0,0 +1,15 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
import os
|
||||
from setuptools import setup, Extension
|
||||
from torch.utils import cpp_extension
|
||||
|
||||
filename = os.path.join(os.path.dirname(__file__),
|
||||
'torch_interop_utils.cc')
|
||||
setup(name='torch_interop_utils',
|
||||
ext_modules=[cpp_extension.CppExtension(name='torch_interop_utils',
|
||||
sources=[filename])],
|
||||
cmdclass={'build_ext': cpp_extension.BuildExtension})
|
||||
|
|
@ -0,0 +1,72 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
#include <torch/extension.h>
|
||||
#include <torch/csrc/autograd/function.h>
|
||||
#include <torch/csrc/autograd/variable.h>
|
||||
|
||||
// In Torch forward run (e.g. THPVariable_apply), ctx of type THPFunction* (which is also a PyObject*)
|
||||
// is created. The ctx is used to run user-defined forward function and backward function as the first
|
||||
// parameter. The same time, a cdata of type std::shared_ptr<PyNode> is created, cdata is owned by:
|
||||
// a). forward run output tensors as grad_fn_ property. (The full hierarchy is: Tensor own
|
||||
// shared_pointer<TensorImpl>; TensorImpl owns std::unique_ptr<AutogradMeta>; AutogradMeta
|
||||
// manages grad_/grad_fn_/grad_accumulator_. Among them, grad_fn_ is std::shared_ptr<PyNode>,
|
||||
// the so called gradient function.)
|
||||
// b). the consumer operator of forward run outputs, will let its own PyNode/Node own the grad_fn_
|
||||
// (of type std::shared_ptr<PyNode>) of all inputs that require grad.
|
||||
// BUT, if we run torch computation within PythonOp, b) is lost. SO, for some cases, where forward outputs
|
||||
// are not used and freed before backward function runs, the grad_fn_ (std::shared_ptr<PyNode>) references
|
||||
// in a) will be released. Without b)'s reference, grad_fn_ release PyNode as reference count reach 0;
|
||||
// Then when PythonOpGrad runs, segment fault.
|
||||
//
|
||||
// So we add b)'s reference in this Pool when forward run returns; dereference from this Pool when backward
|
||||
// completes, then ~PyNode() is called, which subsquently calls ~THPFunction() destorying ctx.
|
||||
class PyNodeSharedPointerPool {
|
||||
public:
|
||||
static PyNodeSharedPointerPool& GetInstance() {
|
||||
static PyNodeSharedPointerPool pool;
|
||||
return pool;
|
||||
};
|
||||
|
||||
void RegisterGradFunc(const size_t& ctx_address, torch::autograd::AutogradMeta* autograd_meta){
|
||||
auto it = grad_fns_.find(ctx_address);
|
||||
TORCH_CHECK(it == grad_fns_.end(), "should not register grad_fn twice for ctx ", ctx_address);
|
||||
|
||||
// Add new entry if key hasn't been registered.
|
||||
grad_fns_.emplace(ctx_address, std::move(autograd_meta->grad_fn_));
|
||||
};
|
||||
|
||||
void UnRegisterGradFunc(const size_t& ctx_address){
|
||||
auto it = grad_fns_.find(ctx_address);
|
||||
TORCH_CHECK(it != grad_fns_.end(), "fail to find grad_fn for ctx ", ctx_address);
|
||||
|
||||
grad_fns_.erase(ctx_address);
|
||||
};
|
||||
|
||||
private:
|
||||
PyNodeSharedPointerPool(){};
|
||||
~PyNodeSharedPointerPool(){};
|
||||
|
||||
PyNodeSharedPointerPool(const PyNodeSharedPointerPool&) = delete;
|
||||
PyNodeSharedPointerPool& operator=(const PyNodeSharedPointerPool&) = delete;
|
||||
PyNodeSharedPointerPool(PyNodeSharedPointerPool&&) = delete;
|
||||
PyNodeSharedPointerPool& operator=(PyNodeSharedPointerPool&&) = delete;
|
||||
|
||||
std::unordered_map<size_t, std::shared_ptr<torch::autograd::Node>> grad_fns_;
|
||||
};
|
||||
|
||||
|
||||
void register_grad_fn(size_t ctx_address, at::Tensor target)
|
||||
{
|
||||
torch::autograd::AutogradMeta* autograd_meta = torch::autograd::impl::get_autograd_meta(target);
|
||||
PyNodeSharedPointerPool::GetInstance().RegisterGradFunc(ctx_address, autograd_meta);
|
||||
}
|
||||
|
||||
void unregister_grad_fn(size_t ctx_address)
|
||||
{
|
||||
PyNodeSharedPointerPool::GetInstance().UnRegisterGradFunc(ctx_address);
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("register_grad_fn", ®ister_grad_fn, "increase grad_fn shared pointer reference.");
|
||||
m.def("unregister_grad_fn", &unregister_grad_fn, "release grad_fn shared pointer referece.");
|
||||
}
|
||||
|
|
@ -74,6 +74,68 @@ def test_GeLU():
|
|||
run_training_test_and_compare(model_builder, input_generator, label_input)
|
||||
|
||||
|
||||
def test_GeLU_custom_func_rets_not_as_module_output():
|
||||
@torch.jit.script
|
||||
def bias_gelu(bias, y):
|
||||
x = bias + y
|
||||
return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
|
||||
|
||||
@torch.jit.script
|
||||
def bias_gelu_backward(g, bias, y):
|
||||
x = bias + y
|
||||
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
|
||||
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 +
|
||||
0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
|
||||
return ff*g
|
||||
|
||||
class GeLUFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input, bias):
|
||||
ctx.save_for_backward(input, bias)
|
||||
return bias_gelu(bias, input)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input, bias = ctx.saved_tensors
|
||||
tmp = bias_gelu_backward(grad_output, bias, input)
|
||||
return tmp, tmp
|
||||
|
||||
class GeLUModel(torch.nn.Module):
|
||||
def __init__(self, output_size):
|
||||
super(GeLUModel, self).__init__()
|
||||
self.relu = GeLUFunction.apply
|
||||
self.bias = Parameter(torch.empty(
|
||||
output_size,
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=torch.float))
|
||||
|
||||
with torch.no_grad():
|
||||
self.bias.uniform_()
|
||||
|
||||
def forward(self, model_input):
|
||||
out = self.relu(model_input, self.bias)
|
||||
# add * 9 by intention to make custom function's output
|
||||
# NOT as module outputs (which are consumed by subsquent computations).
|
||||
# This aims to trigger a GC for "out", saying, out is released,
|
||||
# the underlying std::shared<PyNode> still have other references.
|
||||
# Otherwise, a segementfault will be triggered.
|
||||
out = out * 9
|
||||
return out
|
||||
|
||||
output_size = 1024
|
||||
|
||||
def model_builder():
|
||||
return GeLUModel(output_size)
|
||||
|
||||
def input_generator():
|
||||
return torch.randn(output_size, dtype=torch.float)
|
||||
|
||||
# generate a label that have same shape as forward output.
|
||||
label_input = torch.ones([output_size])
|
||||
|
||||
run_training_test_and_compare(model_builder, input_generator, label_input)
|
||||
|
||||
|
||||
def test_MegatronF():
|
||||
# MegatronGFunction is tested in distributed test files.
|
||||
class MegatronFFunction(torch.autograd.Function):
|
||||
|
|
|
|||
2
setup.py
2
setup.py
|
|
@ -323,8 +323,10 @@ if enable_training:
|
|||
'onnxruntime.training.ortmodule.experimental.json_config',
|
||||
'onnxruntime.training.ortmodule.torch_cpp_extensions',
|
||||
'onnxruntime.training.ortmodule.torch_cpp_extensions.aten_op_executor',
|
||||
'onnxruntime.training.ortmodule.torch_cpp_extensions.torch_interop_utils',
|
||||
'onnxruntime.training.ortmodule.torch_cpp_extensions.torch_gpu_allocator'])
|
||||
package_data['onnxruntime.training.ortmodule.torch_cpp_extensions.aten_op_executor'] = ['*.cc']
|
||||
package_data['onnxruntime.training.ortmodule.torch_cpp_extensions.torch_interop_utils'] = ['*.cc']
|
||||
package_data['onnxruntime.training.ortmodule.torch_cpp_extensions.torch_gpu_allocator'] = ['*.cc']
|
||||
requirements_file = "requirements-training.txt"
|
||||
# with training, we want to follow this naming convention:
|
||||
|
|
|
|||
Loading…
Reference in a new issue