[ONNX] Adjust and add deprecation messages (#146639)

Adjust and add deprecation messages to torch.onnx utilities and verification methods because they are only related to TorchScript export and are obsolete.

Removed unused `_exporter_states.py` and removed the internal deprecation module in favor of the typing_extensions deprecated decorator.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146639
Approved by: https://github.com/titaiwangms
This commit is contained in:
Justin Chu 2025-02-08 05:09:16 +00:00 committed by PyTorch MergeBot
parent 2328dcccb9
commit 63c2909ae3
7 changed files with 120 additions and 151 deletions

View file

@ -49,6 +49,7 @@ __all__ = [
]
from typing import Any, Callable, TYPE_CHECKING
from typing_extensions import deprecated
import torch
from torch import _C
@ -168,6 +169,19 @@ def export(
) -> ONNXProgram | None:
r"""Exports a model into ONNX format.
.. versionchanged:: 2.6
*training* is now deprecated. Instead, set the training mode of the model before exporting.
.. versionchanged:: 2.6
*operator_export_type* is now deprecated. Only ONNX is supported.
.. versionchanged:: 2.6
*do_constant_folding* is now deprecated. It is always enabled.
.. versionchanged:: 2.6
*export_modules_as_functions* is now deprecated.
.. versionchanged:: 2.6
*autograd_inlining* is now deprecated.
.. versionchanged:: 2.7
*optimize* is now True by default.
Args:
model: The model to be exported.
args: Example positional inputs. Any non-Tensor arguments will be hard-coded into the
@ -342,6 +356,9 @@ def export(
autograd_inlining: Deprecated.
Flag used to control whether to inline autograd functions.
Refer to https://github.com/pytorch/pytorch/pull/74765 for more details.
Returns:
:class:`torch.onnx.ONNXProgram` if dynamo is True, otherwise None.
"""
if dynamo is True or isinstance(model, torch.export.ExportedProgram):
from torch.onnx._internal.exporter import _compat
@ -402,6 +419,9 @@ def export(
return None
@deprecated(
"torch.onnx.dynamo_export is deprecated since 2.6.0. Please use torch.onnx.export(..., dynamo=True) instead."
)
def dynamo_export(
model: torch.nn.Module | Callable | torch.export.ExportedProgram, # type: ignore[name-defined]
/,
@ -411,6 +431,9 @@ def dynamo_export(
) -> ONNXProgram:
"""Export a torch.nn.Module to an ONNX graph.
.. deprecated:: 2.6
Please use ``torch.onnx.export(..., dynamo=True)`` instead.
Args:
model: The PyTorch model to be exported to ONNX.
model_args: Positional inputs to ``model``.
@ -452,7 +475,6 @@ def dynamo_export(
onnx_program.save("my_dynamic_model.onnx")
"""
# NOTE: The new exporter is experimental and is not enabled by default.
import warnings
from torch.onnx import _flags

View file

@ -1,72 +0,0 @@
"""Utility for deprecating functions."""
import functools
import textwrap
import warnings
from typing import Callable, TypeVar
from typing_extensions import ParamSpec
_T = TypeVar("_T")
_P = ParamSpec("_P")
def deprecated(
since: str, removed_in: str, instructions: str
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
"""Marks functions as deprecated.
It will result in a warning when the function is called and a note in the
docstring.
Args:
since: The version when the function was first deprecated.
removed_in: The version when the function will be removed.
instructions: The action users should take.
"""
def decorator(function: Callable[_P, _T]) -> Callable[_P, _T]:
@functools.wraps(function)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
warnings.warn(
f"'{function.__module__}.{function.__name__}' "
f"is deprecated in version {since} and will be "
f"removed in {removed_in}. Please {instructions}.",
category=DeprecationWarning,
stacklevel=2,
)
return function(*args, **kwargs)
# Add a deprecation note to the docstring.
docstring = function.__doc__ or ""
# Add a note to the docstring.
deprecation_note = textwrap.dedent(
f"""\
.. deprecated:: {since}
Deprecated and will be removed in version {removed_in}.
Please {instructions}.
"""
)
# Split docstring at first occurrence of newline
summary_and_body = docstring.split("\n\n", 1)
if len(summary_and_body) > 1:
summary, body = summary_and_body
# Dedent the body. We cannot do this with the presence of the summary because
# the body contains leading whitespaces when the summary does not.
body = textwrap.dedent(body)
new_docstring_parts = [deprecation_note, "\n\n", summary, body]
else:
summary = summary_and_body[0]
new_docstring_parts = [deprecation_note, "\n\n", summary]
wrapper.__doc__ = "".join(new_docstring_parts)
return wrapper
return decorator

View file

@ -1,12 +0,0 @@
from __future__ import annotations
class ExportTypes:
    """Specifies how the ONNX model is stored.

    Each attribute's value is a human-readable description string of the
    storage format, not a short machine-readable token.
    """

    # TODO(justinchuby): Deprecate and remove this class.
    # Single protobuf (``.onnx``) file.
    PROTOBUF_FILE = "Saves model in the specified protobuf file."
    # Uncompressed ZIP archive.
    ZIP_ARCHIVE = "Saves model in the specified ZIP file (uncompressed)."
    # Compressed ZIP archive.
    COMPRESSED_ZIP_ARCHIVE = "Saves model in the specified ZIP file (compressed)."
    # Plain directory of files.
    DIRECTORY = "Saves model in the specified folder."

View file

@ -81,12 +81,14 @@ class ONNXFakeContext:
@deprecated(
"torch.onnx.dynamo_export is deprecated since 2.6.0. Please use torch.onnx.export(..., dynamo=True) instead.",
category=DeprecationWarning,
"torch.onnx.dynamo_export is deprecated since 2.6.0. Please use torch.onnx.export(..., dynamo=True) instead."
)
class OnnxRegistry:
"""Registry for ONNX functions.
.. deprecated:: 2.6
Please use ``torch.onnx.export(..., dynamo=True)`` instead.
The registry maintains a mapping from qualified names to symbolic functions under a
fixed opset version. It supports registering custom onnx-script functions and for
dispatcher to dispatch calls to the appropriate function.
@ -229,12 +231,14 @@ class OnnxRegistry:
@deprecated(
"torch.onnx.dynamo_export is deprecated since 2.6.0. Please use torch.onnx.export(..., dynamo=True) instead.",
category=DeprecationWarning,
"torch.onnx.dynamo_export is deprecated since 2.6.0. Please use torch.onnx.export(..., dynamo=True) instead."
)
class ExportOptions:
"""Options to influence the TorchDynamo ONNX exporter.
.. deprecated:: 2.6
Please use ``torch.onnx.export(..., dynamo=True)`` instead.
Attributes:
dynamic_shapes: Shape information hint for input/output tensors.
When ``None``, the exporter determines the most compatible setting.
@ -385,8 +389,9 @@ def enable_fake_mode():
It is highly recommended to initialize the model in fake mode when exporting models that
are too large to fit into memory.
NOTE: This function does not support torch.onnx.export(..., dynamo=True, optimize=True), so
please call ONNXProgram.optimize() outside of the function after the model is exported.
.. note::
This function does not support torch.onnx.export(..., dynamo=True, optimize=True).
Please call ONNXProgram.optimize() outside of the function after the model is exported.
Example::
@ -443,12 +448,14 @@ def enable_fake_mode():
@deprecated(
"torch.onnx.dynamo_export is deprecated since 2.6.0. Please use torch.onnx.export(..., dynamo=True) instead.",
category=DeprecationWarning,
"torch.onnx.dynamo_export is deprecated since 2.6.0. Please use torch.onnx.export(..., dynamo=True) instead."
)
class ONNXRuntimeOptions:
"""Options to influence the execution of the ONNX model through ONNX Runtime.
.. deprecated:: 2.6
Please use ``torch.onnx.export(..., dynamo=True)`` instead.
Attributes:
session_options: ONNX Runtime session options.
execution_providers: ONNX Runtime execution providers to use during model execution.
@ -701,8 +708,7 @@ def _assert_dependencies(export_options: ResolvedExportOptions):
@deprecated(
"torch.onnx.dynamo_export is deprecated since 2.6.0. Please use torch.onnx.export(..., dynamo=True) instead.",
category=DeprecationWarning,
"torch.onnx.dynamo_export is deprecated since 2.6.0. Please use torch.onnx.export(..., dynamo=True) instead."
)
def dynamo_export(
model: torch.nn.Module | Callable,
@ -713,6 +719,9 @@ def dynamo_export(
) -> _onnx_program.ONNXProgram:
"""Export a torch.nn.Module to an ONNX graph.
.. deprecated:: 2.6
Please use ``torch.onnx.export(..., dynamo=True)`` instead.
Args:
model: The PyTorch model to be exported to ONNX.
model_args: Positional inputs to ``model``.

View file

@ -15,6 +15,7 @@ import math
import sys
import warnings
from typing import Callable, TYPE_CHECKING
from typing_extensions import deprecated
import torch
import torch._C._onnx as _C_onnx
@ -23,7 +24,7 @@ import torch.onnx
from torch import _C
# Monkey-patch graph manipulation methods on Graph, used for the ONNX symbolics
from torch.onnx import _constants, _deprecation, _type_utils, errors, symbolic_helper
from torch.onnx import _constants, _type_utils, errors, symbolic_helper
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import jit_utils, registration
@ -3315,91 +3316,55 @@ def _unique2(g: jit_utils.GraphContext, input, sorted, return_inverse, return_co
@_onnx_symbolic("aten::_cast_Byte")
@_deprecation.deprecated(
"2.0",
"the future",
"Avoid using this function and create a Cast node instead",
)
@deprecated("Avoid using this function and create a Cast node instead")
def _cast_Byte(g: jit_utils.GraphContext, input, non_blocking):
return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.UINT8)
@_onnx_symbolic("aten::_cast_Char")
@_deprecation.deprecated(
"2.0",
"the future",
"Avoid using this function and create a Cast node instead",
)
@deprecated("Avoid using this function and create a Cast node instead")
def _cast_Char(g: jit_utils.GraphContext, input, non_blocking):
return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT8)
@_onnx_symbolic("aten::_cast_Short")
@_deprecation.deprecated(
"2.0",
"the future",
"Avoid using this function and create a Cast node instead",
)
@deprecated("Avoid using this function and create a Cast node instead")
def _cast_Short(g: jit_utils.GraphContext, input, non_blocking):
return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT16)
@_onnx_symbolic("aten::_cast_Int")
@_deprecation.deprecated(
"2.0",
"the future",
"Avoid using this function and create a Cast node instead",
)
@deprecated("Avoid using this function and create a Cast node instead")
def _cast_Int(g: jit_utils.GraphContext, input, non_blocking):
return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT32)
@_onnx_symbolic("aten::_cast_Long")
@_deprecation.deprecated(
"2.0",
"the future",
"Avoid using this function and create a Cast node instead",
)
@deprecated("Avoid using this function and create a Cast node instead")
def _cast_Long(g: jit_utils.GraphContext, input, non_blocking):
return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT64)
@_onnx_symbolic("aten::_cast_Half")
@_deprecation.deprecated(
"2.0",
"the future",
"Avoid using this function and create a Cast node instead",
)
@deprecated("Avoid using this function and create a Cast node instead")
def _cast_Half(g: jit_utils.GraphContext, input, non_blocking):
return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
@_onnx_symbolic("aten::_cast_Float")
@_deprecation.deprecated(
"2.0",
"the future",
"Avoid using this function and create a Cast node instead",
)
@deprecated("Avoid using this function and create a Cast node instead")
def _cast_Float(g: jit_utils.GraphContext, input, non_blocking):
return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.FLOAT)
@_onnx_symbolic("aten::_cast_Double")
@_deprecation.deprecated(
"2.0",
"the future",
"Avoid using this function and create a Cast node instead",
)
@deprecated("Avoid using this function and create a Cast node instead")
def _cast_Double(g: jit_utils.GraphContext, input, non_blocking):
return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.DOUBLE)
@_onnx_symbolic("aten::_cast_Bool")
@_deprecation.deprecated(
"2.0",
"the future",
"Avoid using this function and create a Cast node instead",
)
@deprecated("Avoid using this function and create a Cast node instead")
def _cast_Bool(g: jit_utils.GraphContext, input, non_blocking):
return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.BOOL)

View file

@ -14,13 +14,14 @@ import re
import typing
import warnings
from typing import Any, Callable, cast
from typing_extensions import deprecated
import torch
import torch._C._onnx as _C_onnx
import torch.jit._trace
import torch.serialization
from torch import _C
from torch.onnx import _constants, _deprecation, errors, symbolic_helper # noqa: F401
from torch.onnx import _constants, errors, symbolic_helper # noqa: F401
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import diagnostics, jit_utils, onnx_proto_utils, registration
@ -55,11 +56,15 @@ def is_in_onnx_export() -> bool:
_params_dict = {} # type: ignore[var-annotated]
@deprecated("Please set training mode before exporting the model", category=None)
@contextlib.contextmanager
def select_model_mode_for_export(model, mode: _C_onnx.TrainingMode):
r"""A context manager to temporarily set the training mode of ``model``
"""A context manager to temporarily set the training mode of ``model``
to ``mode``, resetting it when we exit the with-block.
.. deprecated:: 2.7
Please set training mode before exporting the model.
Args:
model: Same type and meaning as ``model`` arg to :func:`export`.
mode: Same type and meaning as ``training`` arg to :func:`export`.
@ -103,8 +108,14 @@ def select_model_mode_for_export(model, mode: _C_onnx.TrainingMode):
model.train(originally_training)
@deprecated("Please remove usage of this function", category=None)
@contextlib.contextmanager
def disable_apex_o2_state_dict_hook(model: torch.nn.Module | torch.jit.ScriptFunction):
"""A context manager to temporarily disable the Apex O2 hook that returns.
.. deprecated:: 2.7
Please remove usage of this function.
"""
# Apex O2 hook state_dict to return fp16 weights as fp32.
# Exporter cannot identify them as same tensors.
# Since this hook is only used by optimizer, it is safe to
@ -134,8 +145,14 @@ def disable_apex_o2_state_dict_hook(model: torch.nn.Module | torch.jit.ScriptFun
pass
@deprecated("Please remove usage of this function")
@contextlib.contextmanager
def setup_onnx_logging(verbose: bool):
"""A context manager to temporarily set the ONNX logging verbosity.
.. deprecated:: 2.7
Please remove usage of this function.
"""
is_originally_enabled = _C._jit_is_onnx_log_enabled
if is_originally_enabled or verbose: # type: ignore[truthy-function]
_C._jit_set_onnx_log_enabled(True)
@ -146,8 +163,15 @@ def setup_onnx_logging(verbose: bool):
_C._jit_set_onnx_log_enabled(False)
@deprecated("Please remove usage of this function", category=None)
@contextlib.contextmanager
def exporter_context(model, mode: _C_onnx.TrainingMode, verbose: bool):
"""A context manager to temporarily set the training mode of ``model``
to ``mode``, disable the Apex O2 hook, and set the ONNX logging verbosity.
.. deprecated:: 2.7
Please set training mode before exporting the model.
"""
with select_model_mode_for_export(
model, mode
) as mode_ctx, disable_apex_o2_state_dict_hook(
@ -1153,7 +1177,7 @@ def _model_to_graph(
return graph, params_dict, torch_out
@_deprecation.deprecated("2.5", "the future", "avoid using this function")
@deprecated("Please remove usage of this function")
def unconvertible_ops(
model,
args,
@ -1162,6 +1186,9 @@ def unconvertible_ops(
) -> tuple[_C.Graph, list[str]]:
"""Returns an approximated list of all ops that are yet supported by :mod:`torch.onnx`.
.. deprecated:: 2.5
Please remove usage of this function.
The list is approximated because some ops may be removed during the conversion
process and don't need to be converted. Some other ops may have partial support
that will fail conversion with particular inputs. Please open a Github Issue

View file

@ -44,7 +44,12 @@ _OutputsType = Union[Sequence[_NumericType], Sequence]
class OnnxBackend(enum.Enum):
"""Enum class for ONNX backend used for export verification."""
"""Enum class for ONNX backend used for export verification.
.. deprecated:: 2.7
Consider using ``torch.onnx.export(..., dynamo=True)`` and use the returned
``ONNXProgram`` to test the ONNX model.
"""
REFERENCE = "ONNXReferenceEvaluator"
ONNX_RUNTIME_CPU = "CPUExecutionProvider"
@ -55,6 +60,10 @@ class OnnxBackend(enum.Enum):
class VerificationOptions:
"""Options for ONNX export verification.
.. deprecated:: 2.7
Consider using ``torch.onnx.export(..., dynamo=True)`` and use the returned
``ONNXProgram`` to test the ONNX model.
Attributes:
flatten: If True, unpack nested list/tuple/dict inputs into a flattened list of
Tensors for ONNX. Set this to False if nested structures are to be preserved
@ -775,7 +784,7 @@ def check_export_model_diff(
@typing_extensions.deprecated(
"torch.onnx.verification.* is deprecated. Consider using torch.onnx.export(..., dynamo=True) "
"and use ONNXProgram to test the ONNX model",
category=DeprecationWarning,
category=None,
)
def verify(
model: _ModelType,
@ -797,6 +806,10 @@ def verify(
):
"""Verify model export to ONNX against original PyTorch model.
.. deprecated:: 2.7
Consider using ``torch.onnx.export(..., dynamo=True)`` and use the returned
``ONNXProgram`` to test the ONNX model.
Args:
model (torch.nn.Module or torch.jit.ScriptModule): See :func:`torch.onnx.export`.
input_args (tuple): See :func:`torch.onnx.export`.
@ -866,8 +879,7 @@ def verify(
@typing_extensions.deprecated(
"torch.onnx.verification.* is deprecated. Consider using torch.onnx.export(..., dynamo=True) "
"and use ONNXProgram to test the ONNX model",
category=DeprecationWarning,
"and use ONNXProgram to test the ONNX model"
)
def verify_aten_graph(
graph: torch.Graph,
@ -876,6 +888,12 @@ def verify_aten_graph(
params_dict: dict[str, Any] | None = None,
verification_options: VerificationOptions | None = None,
) -> tuple[AssertionError | None, torch.Graph, _OutputsType, _OutputsType]:
"""Verify aten graph export to ONNX against original PyTorch model.
.. deprecated:: 2.7
Consider using ``torch.onnx.export(..., dynamo=True)`` and use the returned
``ONNXProgram`` to test the ONNX model.
"""
if verification_options is None:
verification_options = VerificationOptions()
if params_dict is None:
@ -1161,12 +1179,16 @@ class OnnxTestCaseRepro:
@typing_extensions.deprecated(
"torch.onnx.verification.* is deprecated. Consider using torch.onnx.export(..., dynamo=True) "
"and use ONNXProgram to test the ONNX model",
category=DeprecationWarning,
"and use ONNXProgram to test the ONNX model"
)
@dataclasses.dataclass
class GraphInfo:
"""GraphInfo contains validation information of a TorchScript graph and its converted ONNX graph."""
"""GraphInfo contains validation information of a TorchScript graph and its converted ONNX graph.
.. deprecated:: 2.7
Consider using ``torch.onnx.export(..., dynamo=True)`` and use the returned
``ONNXProgram`` to test the ONNX model.
"""
graph: torch.Graph
input_args: tuple[Any, ...]
@ -1691,6 +1713,10 @@ def _produced_by(value: torch.Value, nodes: Collection[torch.Node]) -> bool:
return value.node() in nodes
@typing_extensions.deprecated(
"torch.onnx.verification.* is deprecated. Consider using torch.onnx.export(..., dynamo=True) "
"and use ONNXProgram to test the ONNX model"
)
def find_mismatch(
model: torch.nn.Module | torch.jit.ScriptModule,
input_args: tuple[Any, ...],
@ -1703,6 +1729,10 @@ def find_mismatch(
) -> GraphInfo:
r"""Find all mismatches between the original model and the exported model.
.. deprecated:: 2.7
Consider using ``torch.onnx.export(..., dynamo=True)`` and use the returned
``ONNXProgram`` to test the ONNX model.
Experimental. The API is subject to change.
This tool helps debug the mismatch between the original PyTorch model and exported