mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-22 22:01:08 +00:00
Support new symbolic function api from PyTorch with PythonOp (#9880)
* Support new symbolic function api from PyTorch with PythonOp * Specify exact exception * add comments * move comments and arg
This commit is contained in:
parent
93636cbd20
commit
102f9b05e1
1 changed files with 12 additions and 1 deletions
|
|
@ -23,7 +23,7 @@ BANNED_AUTOGRAD_FUNCTION_NAMES = set(
|
|||
[torch.utils.checkpoint.CheckpointFunction.__name__])
|
||||
|
||||
|
||||
def _export(g, n, *args, **kwargs):
|
||||
def _export_pt_1_10(g, n, *args, **kwargs):
|
||||
'''
|
||||
This function exports PythonOp (input: "n") into a graph
|
||||
node in "g". "args" and "kwargs" are inputs to that PythonOp.
|
||||
|
|
@ -164,6 +164,17 @@ def _export(g, n, *args, **kwargs):
|
|||
sys.stderr.flush()
|
||||
raise wrap_exception(ORTModuleONNXModelException, e)
|
||||
|
||||
# Starting from PyTorch 1.11, there has been a change to symbolic function signature
|
||||
# in terms of how additional context is accessed. More info at
|
||||
# https://github.com/pytorch/pytorch/blob/6b02648479d3615fa3260961e24f38dd0f22da94/torch/onnx/symbolic_helper.py#L48
|
||||
# This code can be cleaned up once support for PyTorch version < 1.11 is dropped.
|
||||
try:
|
||||
from torch.onnx import SymbolicContext
|
||||
def _export(ctx: SymbolicContext, g, *args, **kwargs):
|
||||
n = ctx.cur_node
|
||||
return _export_pt_1_10(g, n, *args, **kwargs)
|
||||
except ImportError:
|
||||
_export = _export_pt_1_10
|
||||
|
||||
def _post_process_after_export(exported_model, enable_custom_autograd_function, log_level):
|
||||
if enable_custom_autograd_function:
|
||||
|
|
|
|||
Loading…
Reference in a new issue