Use autograd_inlining for model export (#16665)
### Use autograd_inlining for model export

Starting with certain PyTorch versions, there is an issue with custom autograd.Function inlining during export, even though we register a custom export function for the autograd.Function (e.g. when custom autograd function support is enabled). As a way out, the PyTorch exporter added a new flag that lets us disable the inlining: https://github.com/pytorch/pytorch/pull/104067

The PyTorch change is currently only in the nightly build, so this PR dynamically checks torch.onnx.export's signature and passes `autograd_inlining` only when it exists.

### Motivation and Context

Works around the autograd.Function inlining issue on newer PyTorch builds while staying compatible with releases whose torch.onnx.export does not yet accept the flag.
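In essence, the PR probes torch.onnx.export's signature at runtime and only adds the new flag when the running PyTorch accepts it. Below is a minimal standalone sketch of that pattern; `enable_custom_autograd_function` is a hard-coded stand-in for the `_RuntimeOptions` field used in the real code:

```python
import inspect

import torch


def check_function_has_param(function, param_name: str) -> bool:
    # True when `param_name` appears in the callable's signature.
    return param_name in inspect.signature(function).parameters


# Always-required export kwargs, mirroring the ones shown in the diff below.
export_kwargs = {"export_params": False, "keep_initializers_as_inputs": True}

if check_function_has_param(torch.onnx.export, "autograd_inlining"):
    # Only inline autograd when custom autograd.Function support is off;
    # custom autograd.Functions cannot be inlined into the exported graph.
    enable_custom_autograd_function = False  # stand-in for the runtime option
    export_kwargs["autograd_inlining"] = not enable_custom_autograd_function
```

On an older torch without the flag, `export_kwargs` simply never gains the `autograd_inlining` key, so torch.onnx.export is never handed an argument it does not understand.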
This commit is contained in:
parent
596dbe277e
commit
2449ded20f
2 changed files with 15 additions and 1 deletion
`_graph_execution_manager.py` (filename inferred from the `GraphExecutionManager` class and relative imports in the hunks):

```diff
@@ -33,6 +33,7 @@ from ._gradient_accumulation_manager import GradientAccumulationManager
 from ._graph_execution_interface import GraphExecutionInterface
 from ._io import _FlattenedModule, _InputInfo, _ModelInputOutputSchemaType
 from ._runtime_inspector import RuntimeInspector
+from ._utils import check_function_has_param
 from .options import DebugOptions, LogLevel, _RuntimeOptions
 from .torch_cpp_extensions.cpu.aten_op_executor import load_aten_op_executor_cpp_extension
 
@@ -335,6 +336,15 @@ class GraphExecutionManager(GraphExecutionInterface):
             "export_params": False,
             "keep_initializers_as_inputs": True,
         }
+
+        if check_function_has_param(torch.onnx.export, "autograd_inlining"):
+            # From some PyTorch version, autograd_inlining is a valid argument.
+            # We allow it to be True if custom autograd function is disabled (where autograd.Function
+            # anyway is not supported in ONNX until it can be inlined).
+            required_export_kwargs[
+                "autograd_inlining"
+            ] = not self._runtime_options.enable_custom_autograd_function
+
         invalid_args = self._export_extra_kwargs.keys() & required_export_kwargs.keys()
         assert (
             len(invalid_args) == 0
```
`_utils.py` (filename inferred from the new `from ._utils import check_function_has_param` import above):

```diff
@@ -12,7 +12,7 @@ import os
 import random
 import traceback
 import types
-from typing import List, Optional, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
 
@@ -419,3 +419,7 @@ def get_runtime_pytorch_version():
     from packaging import version
 
     return version.parse(torch.__version__.split("+")[0])
+
+
+def check_function_has_param(function: Callable, param_name: str) -> bool:
+    return param_name in inspect.signature(function).parameters
```
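For a quick illustration of the new helper, here is a usage sketch; `fake_export` is a hypothetical stand-in function, not part of the diff. (The diff adds no `import inspect`, which suggests `_utils.py` already imports it.)

```python
import inspect
from typing import Callable


def check_function_has_param(function: Callable, param_name: str) -> bool:
    # Same logic as the helper added in this PR: inspect the callable's
    # signature and test for the parameter by name.
    return param_name in inspect.signature(function).parameters


def fake_export(model, f, autograd_inlining: bool = True):  # hypothetical signature
    ...


assert check_function_has_param(fake_export, "autograd_inlining")
assert not check_function_has_param(fake_export, "opset_version")
```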