Fix PythonOp exporter (#9318)

Register PythonOp exporter with the right symbol.
This commit is contained in:
Wei-Sheng Chin 2021-10-22 10:45:45 -07:00 committed by GitHub
parent 5adf175847
commit beddbdec5a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -15,10 +15,12 @@ class Enabler(object):
def state(self, val):
self._state = val
custom_autograd_function_enabler = Enabler()
# Initialize static objects needed to run custom autograd.Function's.
def enable_custom_autograd_support():
# Initialize static objects needed to run custom autograd.Function's.
from onnxruntime.capi._pybind_state import register_forward_runner, register_backward_runner, unregister_python_functions
from torch.onnx import register_custom_op_symbolic
from ._custom_autograd_function_exporter import _export
@ -31,6 +33,12 @@ def enable_custom_autograd_support():
# Unregister all python functions automatically upon normal interpreter termination.
atexit.register(unregister_python_functions)
register_custom_op_symbolic('::prim_PythonOp', _export, 1)
try:
# This is for the latest Pytorch nightly after this commit:
# https://github.com/pytorch/pytorch/commit/11bc435622e6b7207bbf37ed1aafe999e1f296ec
register_custom_op_symbolic('prim::PythonOp', _export, 1)
except:
# This applies to Pytorch 1.9 and 1.9.1.
register_custom_op_symbolic('::prim_PythonOp', _export, 1)
custom_autograd_function_enabler.state = True