mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[ONNX] Create deprecation warning on dynamo_export (#146425)
Reland #146003 Deprecation of `torch.onnx.dynamo_export`: * [`torch/onnx/_internal/_exporter_legacy.py`]: Added deprecation warnings to the `OnnxRegistry`, `ExportOptions`, `ONNXRuntimeOptions`, and `dynamo_export` functions, indicating that `torch.onnx.dynamo_export` is deprecated since version 2.6.0 and should be replaced with `torch.onnx.export(..., dynamo=True)`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/146425 Approved by: https://github.com/titaiwangms, https://github.com/atalman
This commit is contained in:
parent
fa0592b568
commit
41e6d189a3
7 changed files with 37 additions and 22 deletions
|
|
@ -341,22 +341,6 @@ class TestONNXExport(pytorch_test_common.ExportTestCase):
|
|||
f = io.BytesIO()
|
||||
torch.onnx.export(foo, (torch.zeros(1, 2, 3)), f)
|
||||
|
||||
def test_listconstruct_erasure(self):
|
||||
class FooMod(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
mask = x < 0.0
|
||||
return x[mask]
|
||||
|
||||
f = io.BytesIO()
|
||||
torch.onnx.export(
|
||||
FooMod(),
|
||||
(torch.rand(3, 4),),
|
||||
f,
|
||||
add_node_names=False,
|
||||
do_constant_folding=False,
|
||||
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
|
||||
)
|
||||
|
||||
def test_export_dynamic_slice(self):
|
||||
class DynamicSliceExportMod(torch.jit.ScriptModule):
|
||||
@torch.jit.script_method
|
||||
|
|
|
|||
|
|
@ -165,7 +165,6 @@ def export(
|
|||
custom_opsets: Mapping[str, int] | None = None,
|
||||
export_modules_as_functions: bool | Collection[type[torch.nn.Module]] = False,
|
||||
autograd_inlining: bool = True,
|
||||
**_: Any, # ignored options
|
||||
) -> ONNXProgram | None:
|
||||
r"""Exports a model into ONNX format.
|
||||
|
||||
|
|
@ -477,7 +476,7 @@ def dynamo_export(
|
|||
"You are using an experimental ONNX export logic, which currently only supports dynamic shapes. "
|
||||
"For a more comprehensive set of export options, including advanced features, please consider using "
|
||||
"`torch.onnx.export(..., dynamo=True)`. ",
|
||||
category=FutureWarning,
|
||||
category=DeprecationWarning,
|
||||
)
|
||||
|
||||
if export_options is not None and export_options.dynamic_shapes:
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ def deprecated(
|
|||
f"'{function.__module__}.{function.__name__}' "
|
||||
f"is deprecated in version {since} and will be "
|
||||
f"removed in {removed_in}. Please {instructions}.",
|
||||
category=FutureWarning,
|
||||
category=DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return function(*args, **kwargs)
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ import logging
|
|||
import warnings
|
||||
from collections import defaultdict
|
||||
from typing import Any, Callable, TYPE_CHECKING, TypeVar
|
||||
from typing_extensions import deprecated
|
||||
|
||||
import torch
|
||||
import torch._ops
|
||||
|
|
@ -79,6 +80,10 @@ class ONNXFakeContext:
|
|||
"""List of paths of files that contain the model :meth:`state_dict`"""
|
||||
|
||||
|
||||
@deprecated(
|
||||
"torch.onnx.dynamo_export is deprecated since 2.6.0. Please use torch.onnx.export(..., dynamo=True) instead.",
|
||||
category=DeprecationWarning,
|
||||
)
|
||||
class OnnxRegistry:
|
||||
"""Registry for ONNX functions.
|
||||
|
||||
|
|
@ -223,6 +228,10 @@ class OnnxRegistry:
|
|||
}
|
||||
|
||||
|
||||
@deprecated(
|
||||
"torch.onnx.dynamo_export is deprecated since 2.6.0. Please use torch.onnx.export(..., dynamo=True) instead.",
|
||||
category=DeprecationWarning,
|
||||
)
|
||||
class ExportOptions:
|
||||
"""Options to influence the TorchDynamo ONNX exporter.
|
||||
|
||||
|
|
@ -433,6 +442,10 @@ def enable_fake_mode():
|
|||
) # type: ignore[assignment]
|
||||
|
||||
|
||||
@deprecated(
|
||||
"torch.onnx.dynamo_export is deprecated since 2.6.0. Please use torch.onnx.export(..., dynamo=True) instead.",
|
||||
category=DeprecationWarning,
|
||||
)
|
||||
class ONNXRuntimeOptions:
|
||||
"""Options to influence the execution of the ONNX model through ONNX Runtime.
|
||||
|
||||
|
|
@ -687,6 +700,10 @@ def _assert_dependencies(export_options: ResolvedExportOptions):
|
|||
raise missing_opset("onnxscript")
|
||||
|
||||
|
||||
@deprecated(
|
||||
"torch.onnx.dynamo_export is deprecated since 2.6.0. Please use torch.onnx.export(..., dynamo=True) instead.",
|
||||
category=DeprecationWarning,
|
||||
)
|
||||
def dynamo_export(
|
||||
model: torch.nn.Module | Callable,
|
||||
/,
|
||||
|
|
|
|||
|
|
@ -218,7 +218,6 @@ def export_compat(
|
|||
dump_exported_program: bool = False,
|
||||
artifacts_dir: str | os.PathLike = ".",
|
||||
fallback: bool = False,
|
||||
**_,
|
||||
) -> _onnx_program.ONNXProgram:
|
||||
if opset_version is None:
|
||||
opset_version = onnxscript_apis.torchlib_opset_version()
|
||||
|
|
|
|||
|
|
@ -482,14 +482,14 @@ def export(
|
|||
warnings.warn(
|
||||
"Setting `operator_export_type` to something other than default is deprecated. "
|
||||
"The option will be removed in a future release.",
|
||||
category=FutureWarning,
|
||||
category=DeprecationWarning,
|
||||
)
|
||||
if training == _C_onnx.TrainingMode.TRAINING:
|
||||
warnings.warn(
|
||||
"Setting `training` to something other than default is deprecated. "
|
||||
"The option will be removed in a future release. Please set the training mode "
|
||||
"before exporting the model.",
|
||||
category=FutureWarning,
|
||||
category=DeprecationWarning,
|
||||
)
|
||||
|
||||
args = (args,) if isinstance(args, torch.Tensor) else args
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ import io
|
|||
import itertools
|
||||
import os
|
||||
import tempfile
|
||||
import typing_extensions
|
||||
import warnings
|
||||
from collections.abc import Collection, Mapping, Sequence
|
||||
from typing import Any, Callable, Union
|
||||
|
|
@ -771,6 +772,11 @@ def check_export_model_diff(
|
|||
)
|
||||
|
||||
|
||||
@typing_extensions.deprecated(
|
||||
"torch.onnx.verification.* is deprecated. Consider using torch.onnx.export(..., dynamo=True) "
|
||||
"and use ONNXProgram to test the ONNX model",
|
||||
category=DeprecationWarning,
|
||||
)
|
||||
def verify(
|
||||
model: _ModelType,
|
||||
input_args: _InputArgsType,
|
||||
|
|
@ -858,6 +864,11 @@ def verify(
|
|||
)
|
||||
|
||||
|
||||
@typing_extensions.deprecated(
|
||||
"torch.onnx.verification.* is deprecated. Consider using torch.onnx.export(..., dynamo=True) "
|
||||
"and use ONNXProgram to test the ONNX model",
|
||||
category=DeprecationWarning,
|
||||
)
|
||||
def verify_aten_graph(
|
||||
graph: torch.Graph,
|
||||
input_args: tuple[Any, ...],
|
||||
|
|
@ -1148,6 +1159,11 @@ class OnnxTestCaseRepro:
|
|||
_compare_onnx_pytorch_outputs_in_np(run_outputs, expected_outs, options)
|
||||
|
||||
|
||||
@typing_extensions.deprecated(
|
||||
"torch.onnx.verification.* is deprecated. Consider using torch.onnx.export(..., dynamo=True) "
|
||||
"and use ONNXProgram to test the ONNX model",
|
||||
category=DeprecationWarning,
|
||||
)
|
||||
@dataclasses.dataclass
|
||||
class GraphInfo:
|
||||
"""GraphInfo contains validation information of a TorchScript graph and its converted ONNX graph."""
|
||||
|
|
|
|||
Loading…
Reference in a new issue