mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-28 22:56:32 +00:00
Fix segment fault for alltoall (#12701)
* fix segment fault * formatting
This commit is contained in:
parent
19ca2a0089
commit
a0c25e5c2f
3 changed files with 29 additions and 16 deletions
|
|
@ -14,7 +14,7 @@ def FP16_Optimizer(optimizer, **kwargs):
|
|||
Apex, DeepSpeed, Megatron-LM.
|
||||
|
||||
Usage:
|
||||
1. DeepSpeed ZeRO Optimizer Override:
|
||||
1. DeepSpeed ZeRO Optimizer Override:
|
||||
|
||||
>>> from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer
|
||||
>>> optimizer = Adam(param_groups,
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ class Enabler(object):
|
|||
|
||||
@property
|
||||
def already_enabled(self):
|
||||
# Once enabled, this flag will be True.
|
||||
return self._already_enabled
|
||||
|
||||
@state.setter
|
||||
|
|
@ -30,7 +31,7 @@ class Enabler(object):
|
|||
custom_autograd_function_enabler = Enabler()
|
||||
|
||||
# Legacy API to enable the custom autograd, keep its name with default value for compatibility.
|
||||
def enable_custom_autograd_support(enable=True):
|
||||
def enable_custom_autograd_support(to_enable=True):
|
||||
|
||||
import atexit
|
||||
|
||||
|
|
@ -43,12 +44,13 @@ def enable_custom_autograd_support(enable=True):
|
|||
)
|
||||
from onnxruntime.training.ortmodule.torch_cpp_extensions import torch_interop_utils
|
||||
|
||||
from ._custom_autograd_function_exporter import _export
|
||||
from ._custom_autograd_function_runner import call_python_backward_function, call_python_forward_function
|
||||
from ._custom_autograd_function_exporter import _clear_nontensor_object_references, _export
|
||||
|
||||
if enable is True:
|
||||
if to_enable is True and custom_autograd_function_enabler.state is False:
|
||||
if custom_autograd_function_enabler.already_enabled is False:
|
||||
# Initialize static objects needed to run custom autograd.Function's.
|
||||
from ._custom_autograd_function_runner import call_python_backward_function, call_python_forward_function
|
||||
|
||||
register_forward_runner(call_python_forward_function)
|
||||
register_backward_runner(call_python_backward_function)
|
||||
|
||||
|
|
@ -57,6 +59,8 @@ def enable_custom_autograd_support(enable=True):
|
|||
# Clear all gradient functions, to avoid a deadlock issue.
|
||||
# Check the called function for more detailed comments.
|
||||
atexit.register(torch_interop_utils.clear_all_grad_fns)
|
||||
# Clear all non-tensor object references (for example, a ProcessGroup passed to PythonOp).
|
||||
atexit.register(_clear_nontensor_object_references)
|
||||
|
||||
try:
|
||||
# This is for the latest Pytorch nightly after this commit:
|
||||
|
|
@ -67,17 +71,16 @@ def enable_custom_autograd_support(enable=True):
|
|||
register_custom_op_symbolic("::prim_PythonOp", _export, 1)
|
||||
|
||||
custom_autograd_function_enabler.state = True
|
||||
else:
|
||||
if custom_autograd_function_enabler.already_enabled is True:
|
||||
# We don't need to remove the registered runner because it won't be used if we disable the feature.
|
||||
# But we need to unregister the PythonOp custom operator function.
|
||||
try:
|
||||
# This is for the latest Pytorch nightly after this commit:
|
||||
# https://github.com/pytorch/pytorch/commit/11bc435622e6b7207bbf37ed1aafe999e1f296ec
|
||||
unregister_custom_op_symbolic("prim::PythonOp", 1)
|
||||
except:
|
||||
# This applies to Pytorch 1.9 and 1.9.1.
|
||||
unregister_custom_op_symbolic("::prim_PythonOp", 1)
|
||||
elif to_enable is False and custom_autograd_function_enabler.state is True:
|
||||
# We don't need to remove the registered runner because it won't be used if we disable the feature.
|
||||
# But we need to unregister the PythonOp custom operator function.
|
||||
try:
|
||||
# This is for the latest Pytorch nightly after this commit:
|
||||
# https://github.com/pytorch/pytorch/commit/11bc435622e6b7207bbf37ed1aafe999e1f296ec
|
||||
unregister_custom_op_symbolic("prim::PythonOp", 1)
|
||||
except:
|
||||
# This applies to Pytorch 1.9 and 1.9.1.
|
||||
unregister_custom_op_symbolic("::prim_PythonOp", 1)
|
||||
|
||||
custom_autograd_function_enabler.state = False
|
||||
|
||||
|
|
|
|||
|
|
@ -24,6 +24,14 @@ from ._fallback import ORTModuleONNXModelException, ORTModuleTorchModelException
|
|||
# at all.
|
||||
BANNED_AUTOGRAD_FUNCTION_NAMES = set([torch.utils.checkpoint.CheckpointFunction.__name__])
|
||||
|
||||
# For each pointer needed for PythonOp execution, we first add the object to a global store to hold a
|
||||
# reference (in case it is released after the module is exported).
|
||||
NONTENSOR_OBJECT_POINTER_STORE = {}
|
||||
|
||||
|
||||
def _clear_nontensor_object_references():
|
||||
NONTENSOR_OBJECT_POINTER_STORE.clear()
|
||||
|
||||
|
||||
def _export_pt_1_10(g, n, *args, **kwargs):
|
||||
"""
|
||||
|
|
@ -111,6 +119,8 @@ def _export_pt_1_10(g, n, *args, **kwargs):
|
|||
# All other inputs are accessed via "pointers".
|
||||
input_pointer_scalar_positions.append(i)
|
||||
input_pointer_scalars.append(id(arg))
|
||||
|
||||
NONTENSOR_OBJECT_POINTER_STORE[id(arg)] = arg
|
||||
else:
|
||||
raise wrap_exception(
|
||||
ORTModuleONNXModelException,
|
||||
|
|
|
|||
Loading…
Reference in a new issue