Fix segmentation fault for alltoall (#12701)

* fix segment fault

* formatting
This commit is contained in:
pengwa 2022-08-30 11:27:14 +08:00 committed by GitHub
parent 19ca2a0089
commit a0c25e5c2f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 29 additions and 16 deletions

View file

@ -14,7 +14,7 @@ def FP16_Optimizer(optimizer, **kwargs):
Apex, DeepSpeed, Megatron-LM.
Usage:
1. DeepSpeed ZeRO Optimizer Override
1. DeepSpeed ZeRO Optimizer Override:
>>> from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer
>>> optimizer = Adam(param_groups,

View file

@ -18,6 +18,7 @@ class Enabler(object):
@property
def already_enabled(self):
# Once enabled, this flag will be True.
return self._already_enabled
@state.setter
@ -30,7 +31,7 @@ class Enabler(object):
custom_autograd_function_enabler = Enabler()
# Legacy API to enable custom autograd support; its name and default value are kept for compatibility.
def enable_custom_autograd_support(enable=True):
def enable_custom_autograd_support(to_enable=True):
import atexit
@ -43,12 +44,13 @@ def enable_custom_autograd_support(enable=True):
)
from onnxruntime.training.ortmodule.torch_cpp_extensions import torch_interop_utils
from ._custom_autograd_function_exporter import _export
from ._custom_autograd_function_runner import call_python_backward_function, call_python_forward_function
from ._custom_autograd_function_exporter import _clear_nontensor_object_references, _export
if enable is True:
if to_enable is True and custom_autograd_function_enabler.state is False:
if custom_autograd_function_enabler.already_enabled is False:
# Initialize static objects needed to run custom autograd.Function's.
from ._custom_autograd_function_runner import call_python_backward_function, call_python_forward_function
register_forward_runner(call_python_forward_function)
register_backward_runner(call_python_backward_function)
@ -57,6 +59,8 @@ def enable_custom_autograd_support(enable=True):
# Clear all gradient functions, to avoid a deadlock issue.
# Check the called function for more detailed comments.
atexit.register(torch_interop_utils.clear_all_grad_fns)
# Clear all non-tensor object references (for example, a ProcessGroup passed to PythonOp).
atexit.register(_clear_nontensor_object_references)
try:
# This is for the latest Pytorch nightly after this commit:
@ -67,17 +71,16 @@ def enable_custom_autograd_support(enable=True):
register_custom_op_symbolic("::prim_PythonOp", _export, 1)
custom_autograd_function_enabler.state = True
else:
if custom_autograd_function_enabler.already_enabled is True:
# We don't need to remove the registered runners because they won't be used once the feature is disabled.
# But we do need to unregister the PythonOp custom operator function.
try:
# This is for the latest Pytorch nightly after this commit:
# https://github.com/pytorch/pytorch/commit/11bc435622e6b7207bbf37ed1aafe999e1f296ec
unregister_custom_op_symbolic("prim::PythonOp", 1)
except:
# This applies to Pytorch 1.9 and 1.9.1.
unregister_custom_op_symbolic("::prim_PythonOp", 1)
elif to_enable is False and custom_autograd_function_enabler.state is True:
# We don't need to remove the registered runners because they won't be used once the feature is disabled.
# But we do need to unregister the PythonOp custom operator function.
try:
# This is for the latest Pytorch nightly after this commit:
# https://github.com/pytorch/pytorch/commit/11bc435622e6b7207bbf37ed1aafe999e1f296ec
unregister_custom_op_symbolic("prim::PythonOp", 1)
except:
# This applies to Pytorch 1.9 and 1.9.1.
unregister_custom_op_symbolic("::prim_PythonOp", 1)
custom_autograd_function_enabler.state = False

View file

@ -24,6 +24,14 @@ from ._fallback import ORTModuleONNXModelException, ORTModuleTorchModelException
# at all.
BANNED_AUTOGRAD_FUNCTION_NAMES = set([torch.utils.checkpoint.CheckpointFunction.__name__])
# For each pointer needed for PythonOp execution, we first add the object to a global store to hold a
# reference to it (in case it would otherwise be released after the module is exported).
NONTENSOR_OBJECT_POINTER_STORE = {}
def _clear_nontensor_object_references():
    """Empty the global store of non-tensor object references.

    Removes every entry from ``NONTENSOR_OBJECT_POINTER_STORE`` in place, so the
    store no longer holds a reference to any object that was added to it.
    """
    pointer_store = NONTENSOR_OBJECT_POINTER_STORE
    pointer_store.clear()
def _export_pt_1_10(g, n, *args, **kwargs):
"""
@ -111,6 +119,8 @@ def _export_pt_1_10(g, n, *args, **kwargs):
# All other inputs are accessed via "pointers".
input_pointer_scalar_positions.append(i)
input_pointer_scalars.append(id(arg))
NONTENSOR_OBJECT_POINTER_STORE[id(arg)] = arg
else:
raise wrap_exception(
ORTModuleONNXModelException,