mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[ONNX] Adjust and add deprecation messages (#146639)
Adjust and add deprecation messages to torch.onnx utilities and verification methods because they are only related to torch script and are obsolete. Removed unused `_exporter_states.py` and removed the internal deprecation module in favor of the typing_extensions deprecated decorator. Pull Request resolved: https://github.com/pytorch/pytorch/pull/146639 Approved by: https://github.com/titaiwangms
This commit is contained in:
parent
2328dcccb9
commit
63c2909ae3
7 changed files with 120 additions and 151 deletions
|
|
@ -49,6 +49,7 @@ __all__ = [
|
|||
]
|
||||
|
||||
from typing import Any, Callable, TYPE_CHECKING
|
||||
from typing_extensions import deprecated
|
||||
|
||||
import torch
|
||||
from torch import _C
|
||||
|
|
@ -168,6 +169,19 @@ def export(
|
|||
) -> ONNXProgram | None:
|
||||
r"""Exports a model into ONNX format.
|
||||
|
||||
.. versionchanged:: 2.6
|
||||
*training* is now deprecated. Instead, set the training mode of the model before exporting.
|
||||
.. versionchanged:: 2.6
|
||||
*operator_export_type* is now deprecated. Only ONNX is supported.
|
||||
.. versionchanged:: 2.6
|
||||
*do_constant_folding* is now deprecated. It is always enabled.
|
||||
.. versionchanged:: 2.6
|
||||
*export_modules_as_functions* is now deprecated.
|
||||
.. versionchanged:: 2.6
|
||||
*autograd_inlining* is now deprecated.
|
||||
.. versionchanged:: 2.7
|
||||
*optimize* is now True by default.
|
||||
|
||||
Args:
|
||||
model: The model to be exported.
|
||||
args: Example positional inputs. Any non-Tensor arguments will be hard-coded into the
|
||||
|
|
@ -342,6 +356,9 @@ def export(
|
|||
autograd_inlining: Deprecated.
|
||||
Flag used to control whether to inline autograd functions.
|
||||
Refer to https://github.com/pytorch/pytorch/pull/74765 for more details.
|
||||
|
||||
Returns:
|
||||
:class:`torch.onnx.ONNXProgram` if dynamo is True, otherwise None.
|
||||
"""
|
||||
if dynamo is True or isinstance(model, torch.export.ExportedProgram):
|
||||
from torch.onnx._internal.exporter import _compat
|
||||
|
|
@ -402,6 +419,9 @@ def export(
|
|||
return None
|
||||
|
||||
|
||||
@deprecated(
|
||||
"torch.onnx.dynamo_export is deprecated since 2.6.0. Please use torch.onnx.export(..., dynamo=True) instead."
|
||||
)
|
||||
def dynamo_export(
|
||||
model: torch.nn.Module | Callable | torch.export.ExportedProgram, # type: ignore[name-defined]
|
||||
/,
|
||||
|
|
@ -411,6 +431,9 @@ def dynamo_export(
|
|||
) -> ONNXProgram:
|
||||
"""Export a torch.nn.Module to an ONNX graph.
|
||||
|
||||
.. deprecated:: 2.6
|
||||
Please use ``torch.onnx.export(..., dynamo=True)`` instead.
|
||||
|
||||
Args:
|
||||
model: The PyTorch model to be exported to ONNX.
|
||||
model_args: Positional inputs to ``model``.
|
||||
|
|
@ -452,7 +475,6 @@ def dynamo_export(
|
|||
onnx_program.save("my_dynamic_model.onnx")
|
||||
"""
|
||||
|
||||
# NOTE: The new exporter is experimental and is not enabled by default.
|
||||
import warnings
|
||||
|
||||
from torch.onnx import _flags
|
||||
|
|
|
|||
|
|
@ -1,72 +0,0 @@
|
|||
"""Utility for deprecating functions."""
|
||||
|
||||
import functools
|
||||
import textwrap
|
||||
import warnings
|
||||
from typing import Callable, TypeVar
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
|
||||
_T = TypeVar("_T")
|
||||
_P = ParamSpec("_P")
|
||||
|
||||
|
||||
def deprecated(
|
||||
since: str, removed_in: str, instructions: str
|
||||
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
|
||||
"""Marks functions as deprecated.
|
||||
|
||||
It will result in a warning when the function is called and a note in the
|
||||
docstring.
|
||||
|
||||
Args:
|
||||
since: The version when the function was first deprecated.
|
||||
removed_in: The version when the function will be removed.
|
||||
instructions: The action users should take.
|
||||
"""
|
||||
|
||||
def decorator(function: Callable[_P, _T]) -> Callable[_P, _T]:
|
||||
@functools.wraps(function)
|
||||
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
|
||||
warnings.warn(
|
||||
f"'{function.__module__}.{function.__name__}' "
|
||||
f"is deprecated in version {since} and will be "
|
||||
f"removed in {removed_in}. Please {instructions}.",
|
||||
category=DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return function(*args, **kwargs)
|
||||
|
||||
# Add a deprecation note to the docstring.
|
||||
docstring = function.__doc__ or ""
|
||||
|
||||
# Add a note to the docstring.
|
||||
deprecation_note = textwrap.dedent(
|
||||
f"""\
|
||||
.. deprecated:: {since}
|
||||
Deprecated and will be removed in version {removed_in}.
|
||||
Please {instructions}.
|
||||
"""
|
||||
)
|
||||
|
||||
# Split docstring at first occurrence of newline
|
||||
summary_and_body = docstring.split("\n\n", 1)
|
||||
|
||||
if len(summary_and_body) > 1:
|
||||
summary, body = summary_and_body
|
||||
|
||||
# Dedent the body. We cannot do this with the presence of the summary because
|
||||
# the body contains leading whitespaces when the summary does not.
|
||||
body = textwrap.dedent(body)
|
||||
|
||||
new_docstring_parts = [deprecation_note, "\n\n", summary, body]
|
||||
else:
|
||||
summary = summary_and_body[0]
|
||||
|
||||
new_docstring_parts = [deprecation_note, "\n\n", summary]
|
||||
|
||||
wrapper.__doc__ = "".join(new_docstring_parts)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
|
@ -1,12 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
|
||||
class ExportTypes:
|
||||
"""Specifies how the ONNX model is stored."""
|
||||
|
||||
# TODO(justinchuby): Deprecate and remove this class.
|
||||
|
||||
PROTOBUF_FILE = "Saves model in the specified protobuf file."
|
||||
ZIP_ARCHIVE = "Saves model in the specified ZIP file (uncompressed)."
|
||||
COMPRESSED_ZIP_ARCHIVE = "Saves model in the specified ZIP file (compressed)."
|
||||
DIRECTORY = "Saves model in the specified folder."
|
||||
|
|
@ -81,12 +81,14 @@ class ONNXFakeContext:
|
|||
|
||||
|
||||
@deprecated(
|
||||
"torch.onnx.dynamo_export is deprecated since 2.6.0. Please use torch.onnx.export(..., dynamo=True) instead.",
|
||||
category=DeprecationWarning,
|
||||
"torch.onnx.dynamo_export is deprecated since 2.6.0. Please use torch.onnx.export(..., dynamo=True) instead."
|
||||
)
|
||||
class OnnxRegistry:
|
||||
"""Registry for ONNX functions.
|
||||
|
||||
.. deprecated:: 2.6
|
||||
Please use ``torch.onnx.export(..., dynamo=True)`` instead.
|
||||
|
||||
The registry maintains a mapping from qualified names to symbolic functions under a
|
||||
fixed opset version. It supports registering custom onnx-script functions and for
|
||||
dispatcher to dispatch calls to the appropriate function.
|
||||
|
|
@ -229,12 +231,14 @@ class OnnxRegistry:
|
|||
|
||||
|
||||
@deprecated(
|
||||
"torch.onnx.dynamo_export is deprecated since 2.6.0. Please use torch.onnx.export(..., dynamo=True) instead.",
|
||||
category=DeprecationWarning,
|
||||
"torch.onnx.dynamo_export is deprecated since 2.6.0. Please use torch.onnx.export(..., dynamo=True) instead."
|
||||
)
|
||||
class ExportOptions:
|
||||
"""Options to influence the TorchDynamo ONNX exporter.
|
||||
|
||||
.. deprecated:: 2.6
|
||||
Please use ``torch.onnx.export(..., dynamo=True)`` instead.
|
||||
|
||||
Attributes:
|
||||
dynamic_shapes: Shape information hint for input/output tensors.
|
||||
When ``None``, the exporter determines the most compatible setting.
|
||||
|
|
@ -385,8 +389,9 @@ def enable_fake_mode():
|
|||
It is highly recommended to initialize the model in fake mode when exporting models that
|
||||
are too large to fit into memory.
|
||||
|
||||
NOTE: This function does not support torch.onnx.export(..., dynamo=True, optimize=True), so
|
||||
please call ONNXProgram.optimize() outside of the function after the model is exported.
|
||||
.. note::
|
||||
This function does not support torch.onnx.export(..., dynamo=True, optimize=True).
|
||||
Please call ONNXProgram.optimize() outside of the function after the model is exported.
|
||||
|
||||
Example::
|
||||
|
||||
|
|
@ -443,12 +448,14 @@ def enable_fake_mode():
|
|||
|
||||
|
||||
@deprecated(
|
||||
"torch.onnx.dynamo_export is deprecated since 2.6.0. Please use torch.onnx.export(..., dynamo=True) instead.",
|
||||
category=DeprecationWarning,
|
||||
"torch.onnx.dynamo_export is deprecated since 2.6.0. Please use torch.onnx.export(..., dynamo=True) instead."
|
||||
)
|
||||
class ONNXRuntimeOptions:
|
||||
"""Options to influence the execution of the ONNX model through ONNX Runtime.
|
||||
|
||||
.. deprecated:: 2.6
|
||||
Please use ``torch.onnx.export(..., dynamo=True)`` instead.
|
||||
|
||||
Attributes:
|
||||
session_options: ONNX Runtime session options.
|
||||
execution_providers: ONNX Runtime execution providers to use during model execution.
|
||||
|
|
@ -701,8 +708,7 @@ def _assert_dependencies(export_options: ResolvedExportOptions):
|
|||
|
||||
|
||||
@deprecated(
|
||||
"torch.onnx.dynamo_export is deprecated since 2.6.0. Please use torch.onnx.export(..., dynamo=True) instead.",
|
||||
category=DeprecationWarning,
|
||||
"torch.onnx.dynamo_export is deprecated since 2.6.0. Please use torch.onnx.export(..., dynamo=True) instead."
|
||||
)
|
||||
def dynamo_export(
|
||||
model: torch.nn.Module | Callable,
|
||||
|
|
@ -713,6 +719,9 @@ def dynamo_export(
|
|||
) -> _onnx_program.ONNXProgram:
|
||||
"""Export a torch.nn.Module to an ONNX graph.
|
||||
|
||||
.. deprecated:: 2.6
|
||||
Please use ``torch.onnx.export(..., dynamo=True)`` instead.
|
||||
|
||||
Args:
|
||||
model: The PyTorch model to be exported to ONNX.
|
||||
model_args: Positional inputs to ``model``.
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ import math
|
|||
import sys
|
||||
import warnings
|
||||
from typing import Callable, TYPE_CHECKING
|
||||
from typing_extensions import deprecated
|
||||
|
||||
import torch
|
||||
import torch._C._onnx as _C_onnx
|
||||
|
|
@ -23,7 +24,7 @@ import torch.onnx
|
|||
from torch import _C
|
||||
|
||||
# Monkey-patch graph manipulation methods on Graph, used for the ONNX symbolics
|
||||
from torch.onnx import _constants, _deprecation, _type_utils, errors, symbolic_helper
|
||||
from torch.onnx import _constants, _type_utils, errors, symbolic_helper
|
||||
from torch.onnx._globals import GLOBALS
|
||||
from torch.onnx._internal import jit_utils, registration
|
||||
|
||||
|
|
@ -3315,91 +3316,55 @@ def _unique2(g: jit_utils.GraphContext, input, sorted, return_inverse, return_co
|
|||
|
||||
|
||||
@_onnx_symbolic("aten::_cast_Byte")
|
||||
@_deprecation.deprecated(
|
||||
"2.0",
|
||||
"the future",
|
||||
"Avoid using this function and create a Cast node instead",
|
||||
)
|
||||
@deprecated("Avoid using this function and create a Cast node instead")
|
||||
def _cast_Byte(g: jit_utils.GraphContext, input, non_blocking):
|
||||
return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.UINT8)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::_cast_Char")
|
||||
@_deprecation.deprecated(
|
||||
"2.0",
|
||||
"the future",
|
||||
"Avoid using this function and create a Cast node instead",
|
||||
)
|
||||
@deprecated("Avoid using this function and create a Cast node instead")
|
||||
def _cast_Char(g: jit_utils.GraphContext, input, non_blocking):
|
||||
return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT8)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::_cast_Short")
|
||||
@_deprecation.deprecated(
|
||||
"2.0",
|
||||
"the future",
|
||||
"Avoid using this function and create a Cast node instead",
|
||||
)
|
||||
@deprecated("Avoid using this function and create a Cast node instead")
|
||||
def _cast_Short(g: jit_utils.GraphContext, input, non_blocking):
|
||||
return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT16)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::_cast_Int")
|
||||
@_deprecation.deprecated(
|
||||
"2.0",
|
||||
"the future",
|
||||
"Avoid using this function and create a Cast node instead",
|
||||
)
|
||||
@deprecated("Avoid using this function and create a Cast node instead")
|
||||
def _cast_Int(g: jit_utils.GraphContext, input, non_blocking):
|
||||
return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT32)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::_cast_Long")
|
||||
@_deprecation.deprecated(
|
||||
"2.0",
|
||||
"the future",
|
||||
"Avoid using this function and create a Cast node instead",
|
||||
)
|
||||
@deprecated("Avoid using this function and create a Cast node instead")
|
||||
def _cast_Long(g: jit_utils.GraphContext, input, non_blocking):
|
||||
return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT64)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::_cast_Half")
|
||||
@_deprecation.deprecated(
|
||||
"2.0",
|
||||
"the future",
|
||||
"Avoid using this function and create a Cast node instead",
|
||||
)
|
||||
@deprecated("Avoid using this function and create a Cast node instead")
|
||||
def _cast_Half(g: jit_utils.GraphContext, input, non_blocking):
|
||||
return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::_cast_Float")
|
||||
@_deprecation.deprecated(
|
||||
"2.0",
|
||||
"the future",
|
||||
"Avoid using this function and create a Cast node instead",
|
||||
)
|
||||
@deprecated("Avoid using this function and create a Cast node instead")
|
||||
def _cast_Float(g: jit_utils.GraphContext, input, non_blocking):
|
||||
return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.FLOAT)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::_cast_Double")
|
||||
@_deprecation.deprecated(
|
||||
"2.0",
|
||||
"the future",
|
||||
"Avoid using this function and create a Cast node instead",
|
||||
)
|
||||
@deprecated("Avoid using this function and create a Cast node instead")
|
||||
def _cast_Double(g: jit_utils.GraphContext, input, non_blocking):
|
||||
return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.DOUBLE)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::_cast_Bool")
|
||||
@_deprecation.deprecated(
|
||||
"2.0",
|
||||
"the future",
|
||||
"Avoid using this function and create a Cast node instead",
|
||||
)
|
||||
@deprecated("Avoid using this function and create a Cast node instead")
|
||||
def _cast_Bool(g: jit_utils.GraphContext, input, non_blocking):
|
||||
return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.BOOL)
|
||||
|
||||
|
|
|
|||
|
|
@ -14,13 +14,14 @@ import re
|
|||
import typing
|
||||
import warnings
|
||||
from typing import Any, Callable, cast
|
||||
from typing_extensions import deprecated
|
||||
|
||||
import torch
|
||||
import torch._C._onnx as _C_onnx
|
||||
import torch.jit._trace
|
||||
import torch.serialization
|
||||
from torch import _C
|
||||
from torch.onnx import _constants, _deprecation, errors, symbolic_helper # noqa: F401
|
||||
from torch.onnx import _constants, errors, symbolic_helper # noqa: F401
|
||||
from torch.onnx._globals import GLOBALS
|
||||
from torch.onnx._internal import diagnostics, jit_utils, onnx_proto_utils, registration
|
||||
|
||||
|
|
@ -55,11 +56,15 @@ def is_in_onnx_export() -> bool:
|
|||
_params_dict = {} # type: ignore[var-annotated]
|
||||
|
||||
|
||||
@deprecated("Please set training mode before exporting the model", category=None)
|
||||
@contextlib.contextmanager
|
||||
def select_model_mode_for_export(model, mode: _C_onnx.TrainingMode):
|
||||
r"""A context manager to temporarily set the training mode of ``model``
|
||||
"""A context manager to temporarily set the training mode of ``model``
|
||||
to ``mode``, resetting it when we exit the with-block.
|
||||
|
||||
.. deprecated:: 2.7
|
||||
Please set training mode before exporting the model.
|
||||
|
||||
Args:
|
||||
model: Same type and meaning as ``model`` arg to :func:`export`.
|
||||
mode: Same type and meaning as ``training`` arg to :func:`export`.
|
||||
|
|
@ -103,8 +108,14 @@ def select_model_mode_for_export(model, mode: _C_onnx.TrainingMode):
|
|||
model.train(originally_training)
|
||||
|
||||
|
||||
@deprecated("Please remove usage of this function", category=None)
|
||||
@contextlib.contextmanager
|
||||
def disable_apex_o2_state_dict_hook(model: torch.nn.Module | torch.jit.ScriptFunction):
|
||||
"""A context manager to temporarily disable the Apex O2 hook that returns.
|
||||
|
||||
.. deprecated:: 2.7
|
||||
Please remove usage of this function.
|
||||
"""
|
||||
# Apex O2 hook state_dict to return fp16 weights as fp32.
|
||||
# Exporter cannot identify them as same tensors.
|
||||
# Since this hook is only used by optimizer, it is safe to
|
||||
|
|
@ -134,8 +145,14 @@ def disable_apex_o2_state_dict_hook(model: torch.nn.Module | torch.jit.ScriptFun
|
|||
pass
|
||||
|
||||
|
||||
@deprecated("Please remove usage of this function")
|
||||
@contextlib.contextmanager
|
||||
def setup_onnx_logging(verbose: bool):
|
||||
"""A context manager to temporarily set the ONNX logging verbosity.
|
||||
|
||||
.. deprecated:: 2.7
|
||||
Please remove usage of this function.
|
||||
"""
|
||||
is_originally_enabled = _C._jit_is_onnx_log_enabled
|
||||
if is_originally_enabled or verbose: # type: ignore[truthy-function]
|
||||
_C._jit_set_onnx_log_enabled(True)
|
||||
|
|
@ -146,8 +163,15 @@ def setup_onnx_logging(verbose: bool):
|
|||
_C._jit_set_onnx_log_enabled(False)
|
||||
|
||||
|
||||
@deprecated("Please remove usage of this function", category=None)
|
||||
@contextlib.contextmanager
|
||||
def exporter_context(model, mode: _C_onnx.TrainingMode, verbose: bool):
|
||||
"""A context manager to temporarily set the training mode of ``model``
|
||||
to ``mode``, disable the Apex O2 hook, and set the ONNX logging verbosity.
|
||||
|
||||
.. deprecated:: 2.7
|
||||
Please set training mode before exporting the model.
|
||||
"""
|
||||
with select_model_mode_for_export(
|
||||
model, mode
|
||||
) as mode_ctx, disable_apex_o2_state_dict_hook(
|
||||
|
|
@ -1153,7 +1177,7 @@ def _model_to_graph(
|
|||
return graph, params_dict, torch_out
|
||||
|
||||
|
||||
@_deprecation.deprecated("2.5", "the future", "avoid using this function")
|
||||
@deprecated("Please remove usage of this function")
|
||||
def unconvertible_ops(
|
||||
model,
|
||||
args,
|
||||
|
|
@ -1162,6 +1186,9 @@ def unconvertible_ops(
|
|||
) -> tuple[_C.Graph, list[str]]:
|
||||
"""Returns an approximated list of all ops that are yet supported by :mod:`torch.onnx`.
|
||||
|
||||
.. deprecated:: 2.5
|
||||
Please remove usage of this function.
|
||||
|
||||
The list is approximated because some ops may be removed during the conversion
|
||||
process and don't need to be converted. Some other ops may have partial support
|
||||
that will fail conversion with particular inputs. Please open a Github Issue
|
||||
|
|
|
|||
|
|
@ -44,7 +44,12 @@ _OutputsType = Union[Sequence[_NumericType], Sequence]
|
|||
|
||||
|
||||
class OnnxBackend(enum.Enum):
|
||||
"""Enum class for ONNX backend used for export verification."""
|
||||
"""Enum class for ONNX backend used for export verification.
|
||||
|
||||
.. deprecated:: 2.7
|
||||
Consider using ``torch.onnx.export(..., dynamo=True)`` and use the returned
|
||||
``ONNXProgram`` to test the ONNX model.
|
||||
"""
|
||||
|
||||
REFERENCE = "ONNXReferenceEvaluator"
|
||||
ONNX_RUNTIME_CPU = "CPUExecutionProvider"
|
||||
|
|
@ -55,6 +60,10 @@ class OnnxBackend(enum.Enum):
|
|||
class VerificationOptions:
|
||||
"""Options for ONNX export verification.
|
||||
|
||||
.. deprecated:: 2.7
|
||||
Consider using ``torch.onnx.export(..., dynamo=True)`` and use the returned
|
||||
``ONNXProgram`` to test the ONNX model.
|
||||
|
||||
Attributes:
|
||||
flatten: If True, unpack nested list/tuple/dict inputs into a flattened list of
|
||||
Tensors for ONNX. Set this to False if nested structures are to be preserved
|
||||
|
|
@ -775,7 +784,7 @@ def check_export_model_diff(
|
|||
@typing_extensions.deprecated(
|
||||
"torch.onnx.verification.* is deprecated. Consider using torch.onnx.export(..., dynamo=True) "
|
||||
"and use ONNXProgram to test the ONNX model",
|
||||
category=DeprecationWarning,
|
||||
category=None,
|
||||
)
|
||||
def verify(
|
||||
model: _ModelType,
|
||||
|
|
@ -797,6 +806,10 @@ def verify(
|
|||
):
|
||||
"""Verify model export to ONNX against original PyTorch model.
|
||||
|
||||
.. deprecated:: 2.7
|
||||
Consider using ``torch.onnx.export(..., dynamo=True)`` and use the returned
|
||||
``ONNXProgram`` to test the ONNX model.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module or torch.jit.ScriptModule): See :func:`torch.onnx.export`.
|
||||
input_args (tuple): See :func:`torch.onnx.export`.
|
||||
|
|
@ -866,8 +879,7 @@ def verify(
|
|||
|
||||
@typing_extensions.deprecated(
|
||||
"torch.onnx.verification.* is deprecated. Consider using torch.onnx.export(..., dynamo=True) "
|
||||
"and use ONNXProgram to test the ONNX model",
|
||||
category=DeprecationWarning,
|
||||
"and use ONNXProgram to test the ONNX model"
|
||||
)
|
||||
def verify_aten_graph(
|
||||
graph: torch.Graph,
|
||||
|
|
@ -876,6 +888,12 @@ def verify_aten_graph(
|
|||
params_dict: dict[str, Any] | None = None,
|
||||
verification_options: VerificationOptions | None = None,
|
||||
) -> tuple[AssertionError | None, torch.Graph, _OutputsType, _OutputsType]:
|
||||
"""Verify aten graph export to ONNX against original PyTorch model.
|
||||
|
||||
.. deprecated:: 2.7
|
||||
Consider using ``torch.onnx.export(..., dynamo=True)`` and use the returned
|
||||
``ONNXProgram`` to test the ONNX model.
|
||||
"""
|
||||
if verification_options is None:
|
||||
verification_options = VerificationOptions()
|
||||
if params_dict is None:
|
||||
|
|
@ -1161,12 +1179,16 @@ class OnnxTestCaseRepro:
|
|||
|
||||
@typing_extensions.deprecated(
|
||||
"torch.onnx.verification.* is deprecated. Consider using torch.onnx.export(..., dynamo=True) "
|
||||
"and use ONNXProgram to test the ONNX model",
|
||||
category=DeprecationWarning,
|
||||
"and use ONNXProgram to test the ONNX model"
|
||||
)
|
||||
@dataclasses.dataclass
|
||||
class GraphInfo:
|
||||
"""GraphInfo contains validation information of a TorchScript graph and its converted ONNX graph."""
|
||||
"""GraphInfo contains validation information of a TorchScript graph and its converted ONNX graph.
|
||||
|
||||
.. deprecated:: 2.7
|
||||
Consider using ``torch.onnx.export(..., dynamo=True)`` and use the returned
|
||||
``ONNXProgram`` to test the ONNX model.
|
||||
"""
|
||||
|
||||
graph: torch.Graph
|
||||
input_args: tuple[Any, ...]
|
||||
|
|
@ -1691,6 +1713,10 @@ def _produced_by(value: torch.Value, nodes: Collection[torch.Node]) -> bool:
|
|||
return value.node() in nodes
|
||||
|
||||
|
||||
@typing_extensions.deprecated(
|
||||
"torch.onnx.verification.* is deprecated. Consider using torch.onnx.export(..., dynamo=True) "
|
||||
"and use ONNXProgram to test the ONNX model"
|
||||
)
|
||||
def find_mismatch(
|
||||
model: torch.nn.Module | torch.jit.ScriptModule,
|
||||
input_args: tuple[Any, ...],
|
||||
|
|
@ -1703,6 +1729,10 @@ def find_mismatch(
|
|||
) -> GraphInfo:
|
||||
r"""Find all mismatches between the original model and the exported model.
|
||||
|
||||
.. deprecated:: 2.7
|
||||
Consider using ``torch.onnx.export(..., dynamo=True)`` and use the returned
|
||||
``ONNXProgram`` to test the ONNX model.
|
||||
|
||||
Experimental. The API is subject to change.
|
||||
|
||||
This tool helps debug the mismatch between the original PyTorch model and exported
|
||||
|
|
|
|||
Loading…
Reference in a new issue