From 102f9b05e17366e7fa474fb9faed6371c09ac644 Mon Sep 17 00:00:00 2001 From: Bowen Bao Date: Thu, 16 Dec 2021 08:08:06 -0800 Subject: [PATCH] Support new symbolic function api from PyTorch with PythonOp (#9880) * Support new symbolic function api from PyTorch with PythonOp * Specify exact exception * add comments * move comments and arg --- .../ortmodule/_custom_autograd_function_exporter.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py index e1120d08de..5a3057e8ec 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py @@ -23,7 +23,7 @@ BANNED_AUTOGRAD_FUNCTION_NAMES = set( [torch.utils.checkpoint.CheckpointFunction.__name__]) -def _export(g, n, *args, **kwargs): +def _export_pt_1_10(g, n, *args, **kwargs): ''' This function exports PythonOp (input: "n") into a graph node in "g". "args" and "kwargs" are inputs to that PythonOp. @@ -164,6 +164,17 @@ def _export(g, n, *args, **kwargs): sys.stderr.flush() raise wrap_exception(ORTModuleONNXModelException, e) +# Starting from PyTorch 1.11, there has been a change to symbolic function signature +# in terms of how additional context is accessed. More info at +# https://github.com/pytorch/pytorch/blob/6b02648479d3615fa3260961e24f38dd0f22da94/torch/onnx/symbolic_helper.py#L48 +# This code can be cleaned up once support for PyTorch version < 1.11 is dropped. +try: + from torch.onnx import SymbolicContext + def _export(ctx: SymbolicContext, g, *args, **kwargs): + n = ctx.cur_node + return _export_pt_1_10(g, n, *args, **kwargs) +except ImportError: + _export = _export_pt_1_10 def _post_process_after_export(exported_model, enable_custom_autograd_function, log_level): if enable_custom_autograd_function: