From 2449ded20f3b48bfbb8a58d161562f61d3129647 Mon Sep 17 00:00:00 2001 From: pengwa Date: Wed, 12 Jul 2023 20:57:24 +0800 Subject: [PATCH] Use autograd_inlining for model export (#16665) ### Use autograd_inlining for model export In some versions of PyTorch, there is an issue related to custom autograd.Function inlining, even though we register a custom export function for the autograd.Function (e.g. when custom autograd function support is enabled). As an option, the PyTorch exporter adds a new flag during export that lets us disable the inlining. https://github.com/pytorch/pytorch/pull/104067 Currently the PyTorch change is only in nightly builds, so this PR dynamically checks torch.onnx.export's signature and decides to use `autograd_inlining` when it exists. ### Motivation and Context --- .../training/ortmodule/_graph_execution_manager.py | 10 ++++++++++ .../orttraining/python/training/ortmodule/_utils.py | 6 +++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index 2e256eb241..26036b6cea 100644 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -33,6 +33,7 @@ from ._gradient_accumulation_manager import GradientAccumulationManager from ._graph_execution_interface import GraphExecutionInterface from ._io import _FlattenedModule, _InputInfo, _ModelInputOutputSchemaType from ._runtime_inspector import RuntimeInspector +from ._utils import check_function_has_param from .options import DebugOptions, LogLevel, _RuntimeOptions from .torch_cpp_extensions.cpu.aten_op_executor import load_aten_op_executor_cpp_extension @@ -335,6 +336,15 @@ class GraphExecutionManager(GraphExecutionInterface): "export_params": False, "keep_initializers_as_inputs": True, } + + if 
check_function_has_param(torch.onnx.export, "autograd_inlining"): + # From some PyTorch version, autograd_inlining is a valid argument. + # We allow it to be True if custom autograd function is disabled (where autograd.Function + # anyway is not supported in ONNX until it can be inlined). + required_export_kwargs[ + "autograd_inlining" + ] = not self._runtime_options.enable_custom_autograd_function + invalid_args = self._export_extra_kwargs.keys() & required_export_kwargs.keys() assert ( len(invalid_args) == 0 diff --git a/orttraining/orttraining/python/training/ortmodule/_utils.py b/orttraining/orttraining/python/training/ortmodule/_utils.py index e10b31a086..3dff18b7b7 100644 --- a/orttraining/orttraining/python/training/ortmodule/_utils.py +++ b/orttraining/orttraining/python/training/ortmodule/_utils.py @@ -12,7 +12,7 @@ import os import random import traceback import types -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import numpy as np import torch @@ -419,3 +419,7 @@ def get_runtime_pytorch_version(): from packaging import version return version.parse(torch.__version__.split("+")[0]) + + +def check_function_has_param(function: Callable, param_name: str) -> bool: + return param_name in inspect.signature(function).parameters