diff --git a/orttraining/orttraining/python/training/optim/fp16_optimizer.py b/orttraining/orttraining/python/training/optim/fp16_optimizer.py index c3864ea711..ce5fb8e09c 100644 --- a/orttraining/orttraining/python/training/optim/fp16_optimizer.py +++ b/orttraining/orttraining/python/training/optim/fp16_optimizer.py @@ -14,7 +14,7 @@ def FP16_Optimizer(optimizer, **kwargs): Apex, DeepSpeed, Megatron-LM. Usage: - 1. DeepSpeed ZeRO Optimizer Override: + 1. DeepSpeed ZeRO Optimizer Override: >>> from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer >>> optimizer = Adam(param_groups, diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function.py index 19d08d68a6..6071bd0f8b 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function.py @@ -18,6 +18,7 @@ class Enabler(object): @property def already_enabled(self): + # Once enabled, this flag will be True. return self._already_enabled @state.setter @@ -30,7 +31,7 @@ class Enabler(object): custom_autograd_function_enabler = Enabler() # Legacy API to enable the custom autograd, keep its name with default value for compatibility. 
-def enable_custom_autograd_support(enable=True): +def enable_custom_autograd_support(to_enable=True): import atexit @@ -43,12 +44,13 @@ def enable_custom_autograd_support(enable=True): ) from onnxruntime.training.ortmodule.torch_cpp_extensions import torch_interop_utils - from ._custom_autograd_function_exporter import _export - from ._custom_autograd_function_runner import call_python_backward_function, call_python_forward_function + from ._custom_autograd_function_exporter import _clear_nontensor_object_references, _export - if enable is True: + if to_enable is True and custom_autograd_function_enabler.state is False: if custom_autograd_function_enabler.already_enabled is False: # Initialize static objects needed to run custom autograd.Function's. + from ._custom_autograd_function_runner import call_python_backward_function, call_python_forward_function + register_forward_runner(call_python_forward_function) register_backward_runner(call_python_backward_function) @@ -57,6 +59,8 @@ def enable_custom_autograd_support(enable=True): # Clear all gradient functions, to avoid a deadlock issue. # Check the called function for more detailed comments. atexit.register(torch_interop_utils.clear_all_grad_fns) + # Clear all non-tensor object references (for example, a ProcessGroup passed to PythonOp). + atexit.register(_clear_nontensor_object_references) try: # This is for the latest Pytorch nightly after this commit: @@ -67,17 +71,16 @@ def enable_custom_autograd_support(enable=True): register_custom_op_symbolic("::prim_PythonOp", _export, 1) custom_autograd_function_enabler.state = True - else: - if custom_autograd_function_enabler.already_enabled is True: - # We don't need remove the registered runner because it won't be used if we disable the feature. - # But we need unregister the PythonOp custom operator function. 
- try: - # This is for the latest Pytorch nightly after this commit: - # https://github.com/pytorch/pytorch/commit/11bc435622e6b7207bbf37ed1aafe999e1f296ec - unregister_custom_op_symbolic("prim::PythonOp", 1) - except: - # This applies to Pytorch 1.9 and 1.9.1. - unregister_custom_op_symbolic("::prim_PythonOp", 1) + elif to_enable is False and custom_autograd_function_enabler.state is True: + # We don't need to remove the registered runners because they won't be used once the feature is disabled. + # But we do need to unregister the PythonOp custom operator function. + try: + # This is for the latest Pytorch nightly after this commit: + # https://github.com/pytorch/pytorch/commit/11bc435622e6b7207bbf37ed1aafe999e1f296ec + unregister_custom_op_symbolic("prim::PythonOp", 1) + except: + # This applies to Pytorch 1.9 and 1.9.1. + unregister_custom_op_symbolic("::prim_PythonOp", 1) custom_autograd_function_enabler.state = False diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py index ebb43e2524..598a19ee05 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py @@ -24,6 +24,14 @@ from ._fallback import ORTModuleONNXModelException, ORTModuleTorchModelException # at all. BANNED_AUTOGRAD_FUNCTION_NAMES = set([torch.utils.checkpoint.CheckpointFunction.__name__]) +# For each object that PythonOp accesses by pointer at execution time, we first put it into a global store to hold a +# reference (in case it is released after the module is exported). +NONTENSOR_OBJECT_POINTER_STORE = {} + + +def _clear_nontensor_object_references(): + NONTENSOR_OBJECT_POINTER_STORE.clear() + def _export_pt_1_10(g, n, *args, **kwargs): """ @@ -111,6 +119,8 @@ def _export_pt_1_10(g, n, *args, **kwargs): # All other inputs are accessed via "pointers". 
input_pointer_scalar_positions.append(i) input_pointer_scalars.append(id(arg)) + + NONTENSOR_OBJECT_POINTER_STORE[id(arg)] = arg else: raise wrap_exception( ORTModuleONNXModelException,