Use autograd_inlining for model export (#16665)

### Use autograd_inlining for model export

From some versions of PyTorch, there is an issue related to custom
autograd.Function inlining, even though we register custom export
function for the autograd.Function (e.g. when custom autograd function
is enabled).

As an options, PyTorch exporter adds a new flag during export, we can
disable the inline. https://github.com/pytorch/pytorch/pull/104067

Currently the PyTorch change is in nightly built, this PR dynamically
check the torch.onnx.export's signature and decide to use the
`autograd_inlining` when it exists.


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
pengwa 2023-07-12 20:57:24 +08:00 committed by GitHub
parent 596dbe277e
commit 2449ded20f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 15 additions and 1 deletions

View file

@ -33,6 +33,7 @@ from ._gradient_accumulation_manager import GradientAccumulationManager
from ._graph_execution_interface import GraphExecutionInterface
from ._io import _FlattenedModule, _InputInfo, _ModelInputOutputSchemaType
from ._runtime_inspector import RuntimeInspector
from ._utils import check_function_has_param
from .options import DebugOptions, LogLevel, _RuntimeOptions
from .torch_cpp_extensions.cpu.aten_op_executor import load_aten_op_executor_cpp_extension
@ -335,6 +336,15 @@ class GraphExecutionManager(GraphExecutionInterface):
"export_params": False,
"keep_initializers_as_inputs": True,
}
if check_function_has_param(torch.onnx.export, "autograd_inlining"):
# From some PyTorch version, autograd_inlining is a valid argument.
# We allow it to be True if custom autograd function is disabled (where autograd.Function
# anyway is not supported in ONNX until it can be inlined).
required_export_kwargs[
"autograd_inlining"
] = not self._runtime_options.enable_custom_autograd_function
invalid_args = self._export_extra_kwargs.keys() & required_export_kwargs.keys()
assert (
len(invalid_args) == 0

View file

@ -12,7 +12,7 @@ import os
import random
import traceback
import types
from typing import List, Optional, Tuple, Union
from typing import Callable, List, Optional, Tuple, Union
import numpy as np
import torch
@ -419,3 +419,7 @@ def get_runtime_pytorch_version():
from packaging import version
return version.parse(torch.__version__.split("+")[0])
def check_function_has_param(function: Callable, param_name: str) -> bool:
return param_name in inspect.signature(function).parameters