From 63c2909ae3e293dee96bca5af88bc51d8ca0ce10 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sat, 8 Feb 2025 05:09:16 +0000 Subject: [PATCH] [ONNX] Adjust and add deprecation messages (#146639) Adjust and add deprecation messages to torch.onnx utilities and verification methods because they are only related to torch script and are obsolete. Removed unused `_exporter_states.py` and removed the internal deprecation module in favor of the typing_extensions deprecated decorator. Pull Request resolved: https://github.com/pytorch/pytorch/pull/146639 Approved by: https://github.com/titaiwangms --- torch/onnx/__init__.py | 24 +++++++- torch/onnx/_deprecation.py | 72 ------------------------ torch/onnx/_exporter_states.py | 12 ---- torch/onnx/_internal/_exporter_legacy.py | 29 ++++++---- torch/onnx/symbolic_opset9.py | 57 ++++--------------- torch/onnx/utils.py | 33 ++++++++++- torch/onnx/verification.py | 44 ++++++++++++--- 7 files changed, 120 insertions(+), 151 deletions(-) delete mode 100644 torch/onnx/_deprecation.py delete mode 100644 torch/onnx/_exporter_states.py diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py index b7ab2e6cf72..d9296d4915b 100644 --- a/torch/onnx/__init__.py +++ b/torch/onnx/__init__.py @@ -49,6 +49,7 @@ __all__ = [ ] from typing import Any, Callable, TYPE_CHECKING +from typing_extensions import deprecated import torch from torch import _C @@ -168,6 +169,19 @@ def export( ) -> ONNXProgram | None: r"""Exports a model into ONNX format. + .. versionchanged:: 2.6 + *training* is now deprecated. Instead, set the training mode of the model before exporting. + .. versionchanged:: 2.6 + *operator_export_type* is now deprecated. Only ONNX is supported. + .. versionchanged:: 2.6 + *do_constant_folding* is now deprecated. It is always enabled. + .. versionchanged:: 2.6 + *export_modules_as_functions* is now deprecated. + .. versionchanged:: 2.6 + *autograd_inlining* is now deprecated. + .. versionchanged:: 2.7 + *optimize* is now True by default. + Args: model: The model to be exported. args: Example positional inputs. Any non-Tensor arguments will be hard-coded into the @@ -342,6 +356,9 @@ def export( autograd_inlining: Deprecated. Flag used to control whether to inline autograd functions. Refer to https://github.com/pytorch/pytorch/pull/74765 for more details. + + Returns: + :class:`torch.onnx.ONNXProgram` if dynamo is True, otherwise None. """ if dynamo is True or isinstance(model, torch.export.ExportedProgram): from torch.onnx._internal.exporter import _compat @@ -402,6 +419,9 @@ def export( return None +@deprecated( + "torch.onnx.dynamo_export is deprecated since 2.6.0. Please use torch.onnx.export(..., dynamo=True) instead." +) def dynamo_export( model: torch.nn.Module | Callable | torch.export.ExportedProgram, # type: ignore[name-defined] /, @@ -411,6 +431,9 @@ def dynamo_export( ) -> ONNXProgram: """Export a torch.nn.Module to an ONNX graph. + .. deprecated:: 2.6 + Please use ``torch.onnx.export(..., dynamo=True)`` instead. + Args: model: The PyTorch model to be exported to ONNX. model_args: Positional inputs to ``model``. @@ -452,7 +475,6 @@ def dynamo_export( onnx_program.save("my_dynamic_model.onnx") """ - # NOTE: The new exporter is experimental and is not enabled by default. import warnings from torch.onnx import _flags diff --git a/torch/onnx/_deprecation.py b/torch/onnx/_deprecation.py deleted file mode 100644 index 61d9bb264fa..00000000000 --- a/torch/onnx/_deprecation.py +++ /dev/null @@ -1,72 +0,0 @@ -"""Utility for deprecating functions.""" - -import functools -import textwrap -import warnings -from typing import Callable, TypeVar -from typing_extensions import ParamSpec - - -_T = TypeVar("_T") -_P = ParamSpec("_P") - - -def deprecated( - since: str, removed_in: str, instructions: str -) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: - """Marks functions as deprecated. - - It will result in a warning when the function is called and a note in the - docstring. - - Args: - since: The version when the function was first deprecated. - removed_in: The version when the function will be removed. - instructions: The action users should take. - """ - - def decorator(function: Callable[_P, _T]) -> Callable[_P, _T]: - @functools.wraps(function) - def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T: - warnings.warn( - f"'{function.__module__}.{function.__name__}' " - f"is deprecated in version {since} and will be " - f"removed in {removed_in}. Please {instructions}.", - category=DeprecationWarning, - stacklevel=2, - ) - return function(*args, **kwargs) - - # Add a deprecation note to the docstring. - docstring = function.__doc__ or "" - - # Add a note to the docstring. - deprecation_note = textwrap.dedent( - f"""\ - .. deprecated:: {since} - Deprecated and will be removed in version {removed_in}. - Please {instructions}. - """ - ) - - # Split docstring at first occurrence of newline - summary_and_body = docstring.split("\n\n", 1) - - if len(summary_and_body) > 1: - summary, body = summary_and_body - - # Dedent the body. We cannot do this with the presence of the summary because - # the body contains leading whitespaces when the summary does not. - body = textwrap.dedent(body) - - new_docstring_parts = [deprecation_note, "\n\n", summary, body] - else: - summary = summary_and_body[0] - - new_docstring_parts = [deprecation_note, "\n\n", summary] - - wrapper.__doc__ = "".join(new_docstring_parts) - - return wrapper - - return decorator diff --git a/torch/onnx/_exporter_states.py b/torch/onnx/_exporter_states.py deleted file mode 100644 index 2fdf7a7ac95..00000000000 --- a/torch/onnx/_exporter_states.py +++ /dev/null @@ -1,12 +0,0 @@ -from __future__ import annotations - - -class ExportTypes: - """Specifies how the ONNX model is stored.""" - - # TODO(justinchuby): Deprecate and remove this class. - - PROTOBUF_FILE = "Saves model in the specified protobuf file." - ZIP_ARCHIVE = "Saves model in the specified ZIP file (uncompressed)." - COMPRESSED_ZIP_ARCHIVE = "Saves model in the specified ZIP file (compressed)." - DIRECTORY = "Saves model in the specified folder." diff --git a/torch/onnx/_internal/_exporter_legacy.py b/torch/onnx/_internal/_exporter_legacy.py index e0c9b099574..ad196630f9d 100644 --- a/torch/onnx/_internal/_exporter_legacy.py +++ b/torch/onnx/_internal/_exporter_legacy.py @@ -81,12 +81,14 @@ class ONNXFakeContext: @deprecated( - "torch.onnx.dynamo_export is deprecated since 2.6.0. Please use torch.onnx.export(..., dynamo=True) instead.", - category=DeprecationWarning, + "torch.onnx.dynamo_export is deprecated since 2.6.0. Please use torch.onnx.export(..., dynamo=True) instead." ) class OnnxRegistry: """Registry for ONNX functions. + .. deprecated:: 2.6 + Please use ``torch.onnx.export(..., dynamo=True)`` instead. + The registry maintains a mapping from qualified names to symbolic functions under a fixed opset version. It supports registering custom onnx-script functions and for dispatcher to dispatch calls to the appropriate function. @@ -229,12 +231,14 @@ class OnnxRegistry: @deprecated( - "torch.onnx.dynamo_export is deprecated since 2.6.0. Please use torch.onnx.export(..., dynamo=True) instead.", - category=DeprecationWarning, + "torch.onnx.dynamo_export is deprecated since 2.6.0. Please use torch.onnx.export(..., dynamo=True) instead." ) class ExportOptions: """Options to influence the TorchDynamo ONNX exporter. + .. deprecated:: 2.6 + Please use ``torch.onnx.export(..., dynamo=True)`` instead. + Attributes: dynamic_shapes: Shape information hint for input/output tensors. When ``None``, the exporter determines the most compatible setting. @@ -385,8 +389,9 @@ def enable_fake_mode(): It is highly recommended to initialize the model in fake mode when exporting models that are too large to fit into memory. - NOTE: This function does not support torch.onnx.export(..., dynamo=True, optimize=True), so - please call ONNXProgram.optimize() outside of the function after the model is exported. + .. note:: + This function does not support torch.onnx.export(..., dynamo=True, optimize=True). + Please call ONNXProgram.optimize() outside of the function after the model is exported. Example:: @@ -443,12 +448,14 @@ def enable_fake_mode(): @deprecated( - "torch.onnx.dynamo_export is deprecated since 2.6.0. Please use torch.onnx.export(..., dynamo=True) instead.", - category=DeprecationWarning, + "torch.onnx.dynamo_export is deprecated since 2.6.0. Please use torch.onnx.export(..., dynamo=True) instead." ) class ONNXRuntimeOptions: """Options to influence the execution of the ONNX model through ONNX Runtime. + .. deprecated:: 2.6 + Please use ``torch.onnx.export(..., dynamo=True)`` instead. + Attributes: session_options: ONNX Runtime session options. execution_providers: ONNX Runtime execution providers to use during model execution. @@ -701,8 +708,7 @@ def _assert_dependencies(export_options: ResolvedExportOptions): @deprecated( - "torch.onnx.dynamo_export is deprecated since 2.6.0. Please use torch.onnx.export(..., dynamo=True) instead.", - category=DeprecationWarning, + "torch.onnx.dynamo_export is deprecated since 2.6.0. Please use torch.onnx.export(..., dynamo=True) instead." ) def dynamo_export( model: torch.nn.Module | Callable, @@ -713,6 +719,9 @@ def dynamo_export( ) -> _onnx_program.ONNXProgram: """Export a torch.nn.Module to an ONNX graph. + .. deprecated:: 2.6 + Please use ``torch.onnx.export(..., dynamo=True)`` instead. + Args: model: The PyTorch model to be exported to ONNX. model_args: Positional inputs to ``model``. diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py index e213451028d..c73b5a1b23a 100644 --- a/torch/onnx/symbolic_opset9.py +++ b/torch/onnx/symbolic_opset9.py @@ -15,6 +15,7 @@ import math import sys import warnings from typing import Callable, TYPE_CHECKING +from typing_extensions import deprecated import torch import torch._C._onnx as _C_onnx @@ -23,7 +24,7 @@ import torch.onnx from torch import _C # Monkey-patch graph manipulation methods on Graph, used for the ONNX symbolics -from torch.onnx import _constants, _deprecation, _type_utils, errors, symbolic_helper +from torch.onnx import _constants, _type_utils, errors, symbolic_helper from torch.onnx._globals import GLOBALS from torch.onnx._internal import jit_utils, registration @@ -3315,91 +3316,55 @@ def _unique2(g: jit_utils.GraphContext, input, sorted, return_inverse, return_co @_onnx_symbolic("aten::_cast_Byte") -@_deprecation.deprecated( - "2.0", - "the future", - "Avoid using this function and create a Cast node instead", -) +@deprecated("Avoid using this function and create a Cast node instead") def _cast_Byte(g: jit_utils.GraphContext, input, non_blocking): return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.UINT8) @_onnx_symbolic("aten::_cast_Char") -@_deprecation.deprecated( - "2.0", - "the future", - "Avoid using this function and create a Cast node instead", -) +@deprecated("Avoid using this function and create a Cast node instead") def _cast_Char(g: jit_utils.GraphContext, input, non_blocking): return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT8) @_onnx_symbolic("aten::_cast_Short") -@_deprecation.deprecated( - "2.0", - "the future", - "Avoid using this function and create a Cast node instead", -) +@deprecated("Avoid using this function and create a Cast node instead") def _cast_Short(g: jit_utils.GraphContext, input, non_blocking): return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT16) @_onnx_symbolic("aten::_cast_Int") -@_deprecation.deprecated( - "2.0", - "the future", - "Avoid using this function and create a Cast node instead", -) +@deprecated("Avoid using this function and create a Cast node instead") def _cast_Int(g: jit_utils.GraphContext, input, non_blocking): return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT32) @_onnx_symbolic("aten::_cast_Long") -@_deprecation.deprecated( - "2.0", - "the future", - "Avoid using this function and create a Cast node instead", -) +@deprecated("Avoid using this function and create a Cast node instead") def _cast_Long(g: jit_utils.GraphContext, input, non_blocking): return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT64) @_onnx_symbolic("aten::_cast_Half") -@_deprecation.deprecated( - "2.0", - "the future", - "Avoid using this function and create a Cast node instead", -) +@deprecated("Avoid using this function and create a Cast node instead") def _cast_Half(g: jit_utils.GraphContext, input, non_blocking): return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.FLOAT16) @_onnx_symbolic("aten::_cast_Float") -@_deprecation.deprecated( - "2.0", - "the future", - "Avoid using this function and create a Cast node instead", -) +@deprecated("Avoid using this function and create a Cast node instead") def _cast_Float(g: jit_utils.GraphContext, input, non_blocking): return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.FLOAT) @_onnx_symbolic("aten::_cast_Double") -@_deprecation.deprecated( - "2.0", - "the future", - "Avoid using this function and create a Cast node instead", -) +@deprecated("Avoid using this function and create a Cast node instead") def _cast_Double(g: jit_utils.GraphContext, input, non_blocking): return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.DOUBLE) @_onnx_symbolic("aten::_cast_Bool") -@_deprecation.deprecated( - "2.0", - "the future", - "Avoid using this function and create a Cast node instead", -) +@deprecated("Avoid using this function and create a Cast node instead") def _cast_Bool(g: jit_utils.GraphContext, input, non_blocking): return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.BOOL) diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index 39d33cb93b6..87ef237bb0f 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -14,13 +14,14 @@ import re import typing import warnings from typing import Any, Callable, cast +from typing_extensions import deprecated import torch import torch._C._onnx as _C_onnx import torch.jit._trace import torch.serialization from torch import _C -from torch.onnx import _constants, _deprecation, errors, symbolic_helper # noqa: F401 +from torch.onnx import _constants, errors, symbolic_helper # noqa: F401 from torch.onnx._globals import GLOBALS from torch.onnx._internal import diagnostics, jit_utils, onnx_proto_utils, registration @@ -55,11 +56,15 @@ def is_in_onnx_export() -> bool: _params_dict = {} # type: ignore[var-annotated] +@deprecated("Please set training mode before exporting the model", category=None) @contextlib.contextmanager def select_model_mode_for_export(model, mode: _C_onnx.TrainingMode): - r"""A context manager to temporarily set the training mode of ``model`` + """A context manager to temporarily set the training mode of ``model`` to ``mode``, resetting it when we exit the with-block. + .. deprecated:: 2.7 + Please set training mode before exporting the model. + Args: model: Same type and meaning as ``model`` arg to :func:`export`. mode: Same type and meaning as ``training`` arg to :func:`export`. @@ -103,8 +108,14 @@ def select_model_mode_for_export(model, mode: _C_onnx.TrainingMode): model.train(originally_training) +@deprecated("Please remove usage of this function", category=None) @contextlib.contextmanager def disable_apex_o2_state_dict_hook(model: torch.nn.Module | torch.jit.ScriptFunction): + """A context manager to temporarily disable the Apex O2 hook that returns. + + .. deprecated:: 2.7 + Please remove usage of this function. + """ # Apex O2 hook state_dict to return fp16 weights as fp32. # Exporter cannot identify them as same tensors. # Since this hook is only used by optimizer, it is safe to @@ -134,8 +145,14 @@ def disable_apex_o2_state_dict_hook(model: torch.nn.Module | torch.jit.ScriptFun pass +@deprecated("Please remove usage of this function") @contextlib.contextmanager def setup_onnx_logging(verbose: bool): + """A context manager to temporarily set the ONNX logging verbosity. + + .. deprecated:: 2.7 + Please remove usage of this function. + """ is_originally_enabled = _C._jit_is_onnx_log_enabled if is_originally_enabled or verbose: # type: ignore[truthy-function] _C._jit_set_onnx_log_enabled(True) @@ -146,8 +163,15 @@ def setup_onnx_logging(verbose: bool): _C._jit_set_onnx_log_enabled(False) +@deprecated("Please remove usage of this function", category=None) @contextlib.contextmanager def exporter_context(model, mode: _C_onnx.TrainingMode, verbose: bool): + """A context manager to temporarily set the training mode of ``model`` + to ``mode``, disable the Apex O2 hook, and set the ONNX logging verbosity. + + .. deprecated:: 2.7 + Please set training mode before exporting the model. + """ with select_model_mode_for_export( model, mode ) as mode_ctx, disable_apex_o2_state_dict_hook( @@ -1153,7 +1177,7 @@ def _model_to_graph( return graph, params_dict, torch_out -@_deprecation.deprecated("2.5", "the future", "avoid using this function") +@deprecated("Please remove usage of this function") def unconvertible_ops( model, args, @@ -1162,6 +1186,9 @@ def unconvertible_ops( ) -> tuple[_C.Graph, list[str]]: """Returns an approximated list of all ops that are yet supported by :mod:`torch.onnx`. + .. deprecated:: 2.5 + Please remove usage of this function. + The list is approximated because some ops may be removed during the conversion process and don't need to be converted. Some other ops may have partial support that will fail conversion with particular inputs. Please open a Github Issue diff --git a/torch/onnx/verification.py b/torch/onnx/verification.py index a4277d9af37..049195a4bf1 100644 --- a/torch/onnx/verification.py +++ b/torch/onnx/verification.py @@ -44,7 +44,12 @@ _OutputsType = Union[Sequence[_NumericType], Sequence] class OnnxBackend(enum.Enum): - """Enum class for ONNX backend used for export verification.""" + """Enum class for ONNX backend used for export verification. + + .. deprecated:: 2.7 + Consider using ``torch.onnx.export(..., dynamo=True)`` and use the returned + ``ONNXProgram`` to test the ONNX model. + """ REFERENCE = "ONNXReferenceEvaluator" ONNX_RUNTIME_CPU = "CPUExecutionProvider" @@ -55,6 +60,10 @@ class OnnxBackend(enum.Enum): class VerificationOptions: """Options for ONNX export verification. + .. deprecated:: 2.7 + Consider using ``torch.onnx.export(..., dynamo=True)`` and use the returned + ``ONNXProgram`` to test the ONNX model. + Attributes: flatten: If True, unpack nested list/tuple/dict inputs into a flattened list of Tensors for ONNX. Set this to False if nested structures are to be preserved @@ -775,7 +784,7 @@ def check_export_model_diff( @typing_extensions.deprecated( "torch.onnx.verification.* is deprecated. Consider using torch.onnx.export(..., dynamo=True) " "and use ONNXProgram to test the ONNX model", - category=DeprecationWarning, + category=None, ) def verify( model: _ModelType, @@ -797,6 +806,10 @@ def verify( ): """Verify model export to ONNX against original PyTorch model. + .. deprecated:: 2.7 + Consider using ``torch.onnx.export(..., dynamo=True)`` and use the returned + ``ONNXProgram`` to test the ONNX model. + Args: model (torch.nn.Module or torch.jit.ScriptModule): See :func:`torch.onnx.export`. input_args (tuple): See :func:`torch.onnx.export`. @@ -866,8 +879,7 @@ def verify( @typing_extensions.deprecated( "torch.onnx.verification.* is deprecated. Consider using torch.onnx.export(..., dynamo=True) " - "and use ONNXProgram to test the ONNX model", - category=DeprecationWarning, + "and use ONNXProgram to test the ONNX model" ) def verify_aten_graph( graph: torch.Graph, @@ -876,6 +888,12 @@ def verify_aten_graph( params_dict: dict[str, Any] | None = None, verification_options: VerificationOptions | None = None, ) -> tuple[AssertionError | None, torch.Graph, _OutputsType, _OutputsType]: + """Verify aten graph export to ONNX against original PyTorch model. + + .. deprecated:: 2.7 + Consider using ``torch.onnx.export(..., dynamo=True)`` and use the returned + ``ONNXProgram`` to test the ONNX model. + """ if verification_options is None: verification_options = VerificationOptions() if params_dict is None: @@ -1161,12 +1179,16 @@ class OnnxTestCaseRepro: @typing_extensions.deprecated( "torch.onnx.verification.* is deprecated. Consider using torch.onnx.export(..., dynamo=True) " - "and use ONNXProgram to test the ONNX model", - category=DeprecationWarning, + "and use ONNXProgram to test the ONNX model" ) @dataclasses.dataclass class GraphInfo: - """GraphInfo contains validation information of a TorchScript graph and its converted ONNX graph.""" + """GraphInfo contains validation information of a TorchScript graph and its converted ONNX graph. + + .. deprecated:: 2.7 + Consider using ``torch.onnx.export(..., dynamo=True)`` and use the returned + ``ONNXProgram`` to test the ONNX model. + """ graph: torch.Graph input_args: tuple[Any, ...] @@ -1691,6 +1713,10 @@ def _produced_by(value: torch.Value, nodes: Collection[torch.Node]) -> bool: return value.node() in nodes +@typing_extensions.deprecated( + "torch.onnx.verification.* is deprecated. Consider using torch.onnx.export(..., dynamo=True) " + "and use ONNXProgram to test the ONNX model" +) def find_mismatch( model: torch.nn.Module | torch.jit.ScriptModule, input_args: tuple[Any, ...], @@ -1703,6 +1729,10 @@ def find_mismatch( ) -> GraphInfo: r"""Find all mismatches between the original model and the exported model. + .. deprecated:: 2.7 + Consider using ``torch.onnx.export(..., dynamo=True)`` and use the returned + ``ONNXProgram`` to test the ONNX model. + Experimental. The API is subject to change. This tool helps debug the mismatch between the original PyTorch model and exported