From beddbdec5a100f1b94ae68d0c56d6c229fabb33b Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Fri, 22 Oct 2021 10:45:45 -0700 Subject: [PATCH] Fix PythonOp exporter (#9318) Register PythonOp exporter with the right symbol. --- .../training/ortmodule/_custom_autograd_function.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function.py index 9ec6578ba1..90aceb2d11 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function.py @@ -15,10 +15,12 @@ class Enabler(object): def state(self, val): self._state = val + custom_autograd_function_enabler = Enabler() -# Initialize static objects needed to run custom autograd.Function's. def enable_custom_autograd_support(): + # Initialize static objects needed to run custom autograd.Function's. + from onnxruntime.capi._pybind_state import register_forward_runner, register_backward_runner, unregister_python_functions from torch.onnx import register_custom_op_symbolic from ._custom_autograd_function_exporter import _export @@ -31,6 +33,12 @@ def enable_custom_autograd_support(): # Unregister all python functions automatically upon normal interpreter termination. atexit.register(unregister_python_functions) - register_custom_op_symbolic('::prim_PythonOp', _export, 1) + try: + # This is for the latest Pytorch nightly after this commit: + # https://github.com/pytorch/pytorch/commit/11bc435622e6b7207bbf37ed1aafe999e1f296ec + register_custom_op_symbolic('prim::PythonOp', _export, 1) + except: + # This applies to Pytorch 1.9 and 1.9.1. + register_custom_op_symbolic('::prim_PythonOp', _export, 1) custom_autograd_function_enabler.state = True