mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-01 23:30:35 +00:00
Fix PythonOp exporter (#9318)
Register PythonOp exporter with the right symbol.
This commit is contained in:
parent
5adf175847
commit
beddbdec5a
1 changed files with 10 additions and 2 deletions
|
|
@ -15,10 +15,12 @@ class Enabler(object):
|
|||
def state(self, val):
|
||||
self._state = val
|
||||
|
||||
|
||||
custom_autograd_function_enabler = Enabler()
|
||||
|
||||
# Initialize static objects needed to run custom autograd.Function's.
|
||||
def enable_custom_autograd_support():
|
||||
# Initialize static objects needed to run custom autograd.Function's.
|
||||
|
||||
from onnxruntime.capi._pybind_state import register_forward_runner, register_backward_runner, unregister_python_functions
|
||||
from torch.onnx import register_custom_op_symbolic
|
||||
from ._custom_autograd_function_exporter import _export
|
||||
|
|
@ -31,6 +33,12 @@ def enable_custom_autograd_support():
|
|||
# Unregister all python functions automatically upon normal interpreter termination.
|
||||
atexit.register(unregister_python_functions)
|
||||
|
||||
register_custom_op_symbolic('::prim_PythonOp', _export, 1)
|
||||
try:
|
||||
# This is for the latest Pytorch nightly after this commit:
|
||||
# https://github.com/pytorch/pytorch/commit/11bc435622e6b7207bbf37ed1aafe999e1f296ec
|
||||
register_custom_op_symbolic('prim::PythonOp', _export, 1)
|
||||
except:
|
||||
# This applies to Pytorch 1.9 and 1.9.1.
|
||||
register_custom_op_symbolic('::prim_PythonOp', _export, 1)
|
||||
|
||||
custom_autograd_function_enabler.state = True
|
||||
|
|
|
|||
Loading…
Reference in a new issue