diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index 2e256eb241..26036b6cea 100644 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -33,6 +33,7 @@ from ._gradient_accumulation_manager import GradientAccumulationManager from ._graph_execution_interface import GraphExecutionInterface from ._io import _FlattenedModule, _InputInfo, _ModelInputOutputSchemaType from ._runtime_inspector import RuntimeInspector +from ._utils import check_function_has_param from .options import DebugOptions, LogLevel, _RuntimeOptions from .torch_cpp_extensions.cpu.aten_op_executor import load_aten_op_executor_cpp_extension @@ -335,6 +336,15 @@ class GraphExecutionManager(GraphExecutionInterface): "export_params": False, "keep_initializers_as_inputs": True, } + + if check_function_has_param(torch.onnx.export, "autograd_inlining"): + # From some PyTorch version, autograd_inlining is a valid argument. + # We allow it to be True if custom autograd function is disabled (where autograd.Function + # anyway is not supported in ONNX until it can be inlined). + required_export_kwargs[ + "autograd_inlining" + ] = not self._runtime_options.enable_custom_autograd_function + invalid_args = self._export_extra_kwargs.keys() & required_export_kwargs.keys() assert ( len(invalid_args) == 0 diff --git a/orttraining/orttraining/python/training/ortmodule/_utils.py b/orttraining/orttraining/python/training/ortmodule/_utils.py index e10b31a086..3dff18b7b7 100644 --- a/orttraining/orttraining/python/training/ortmodule/_utils.py +++ b/orttraining/orttraining/python/training/ortmodule/_utils.py @@ -12,7 +12,7 @@ import os import random import traceback import types -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import numpy as np import torch @@ -419,3 +419,7 @@ def get_runtime_pytorch_version(): from packaging import version return version.parse(torch.__version__.split("+")[0]) + + +def check_function_has_param(function: Callable, param_name: str) -> bool: + return param_name in inspect.signature(function).parameters