Support new symbolic function api from PyTorch with PythonOp (#9880)

* Support new symbolic function api from PyTorch with PythonOp

* Specify exact exception

* add comments

* move comments and arg
This commit is contained in:
Bowen Bao 2021-12-16 08:08:06 -08:00 committed by GitHub
parent 93636cbd20
commit 102f9b05e1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -23,7 +23,7 @@ BANNED_AUTOGRAD_FUNCTION_NAMES = set(
[torch.utils.checkpoint.CheckpointFunction.__name__])
def _export(g, n, *args, **kwargs):
def _export_pt_1_10(g, n, *args, **kwargs):
'''
This function exports PythonOp (input: "n") into a graph
node in "g". "args" and "kwargs" are inputs to that PythonOp.
@ -164,6 +164,17 @@ def _export(g, n, *args, **kwargs):
sys.stderr.flush()
raise wrap_exception(ORTModuleONNXModelException, e)
# Starting from PyTorch 1.11, there has been a change to symbolic function signature
# in terms of how additional context is accessed. More info at
# https://github.com/pytorch/pytorch/blob/6b02648479d3615fa3260961e24f38dd0f22da94/torch/onnx/symbolic_helper.py#L48
# This code can be cleaned up once support for PyTorch version < 1.11 is dropped.
try:
from torch.onnx import SymbolicContext
def _export(ctx: SymbolicContext, g, *args, **kwargs):
n = ctx.cur_node
return _export_pt_1_10(g, n, *args, **kwargs)
except ImportError:
_export = _export_pt_1_10
def _post_process_after_export(exported_model, enable_custom_autograd_function, log_level):
if enable_custom_autograd_function: