diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function.py index 9ec6578ba1..90aceb2d11 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function.py @@ -15,10 +15,12 @@ class Enabler(object): def state(self, val): self._state = val + custom_autograd_function_enabler = Enabler() -# Initialize static objects needed to run custom autograd.Function's. def enable_custom_autograd_support(): + # Initialize static objects needed to run custom autograd.Function's. + from onnxruntime.capi._pybind_state import register_forward_runner, register_backward_runner, unregister_python_functions from torch.onnx import register_custom_op_symbolic from ._custom_autograd_function_exporter import _export @@ -31,6 +33,12 @@ def enable_custom_autograd_support(): # Unregister all python functions automatically upon normal interpreter termination. atexit.register(unregister_python_functions) - register_custom_op_symbolic('::prim_PythonOp', _export, 1) + try: + # This is for the latest Pytorch nightly after this commit: + # https://github.com/pytorch/pytorch/commit/11bc435622e6b7207bbf37ed1aafe999e1f296ec + register_custom_op_symbolic('prim::PythonOp', _export, 1) + except: + # This applies to Pytorch 1.9 and 1.9.1. + register_custom_op_symbolic('::prim_PythonOp', _export, 1) custom_autograd_function_enabler.state = True