From d72cd4c4e59c388b929923b701d5f8b93c5e8572 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 1 Dec 2022 09:28:02 +0100 Subject: [PATCH] document torch.testing.assert_allclose (#89526) After our failed attempt to remove `assert_allclose` in #87974, we decided to add it to the documentation after all. Although we drop the expected removal date, the function continues to be deprecated in favor of `assert_close`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/89526 Approved by: https://github.com/mruberry --- docs/source/testing.rst | 1 + torch/testing/__init__.py | 3 +- torch/testing/_comparison.py | 60 ++++++++++++++++++++++++- torch/testing/_deprecated.py | 85 ------------------------------------ 4 files changed, 60 insertions(+), 89 deletions(-) delete mode 100644 torch/testing/_deprecated.py diff --git a/docs/source/testing.rst b/docs/source/testing.rst index 122aa651b95..8837c4a0ec1 100644 --- a/docs/source/testing.rst +++ b/docs/source/testing.rst @@ -6,3 +6,4 @@ torch.testing .. autofunction:: assert_close .. autofunction:: make_tensor +.. autofunction:: assert_allclose diff --git a/torch/testing/__init__.py b/torch/testing/__init__.py index d437ed9e972..58b8f828e35 100644 --- a/torch/testing/__init__.py +++ b/torch/testing/__init__.py @@ -1,4 +1,3 @@ from torch._C import FileCheck as FileCheck -from ._comparison import assert_close as assert_close +from ._comparison import assert_allclose, assert_close as assert_close from ._creation import make_tensor as make_tensor -from ._deprecated import * # noqa: F403 diff --git a/torch/testing/_comparison.py b/torch/testing/_comparison.py index 3cc729457cb..71824c9815f 100644 --- a/torch/testing/_comparison.py +++ b/torch/testing/_comparison.py @@ -2,10 +2,12 @@ import abc import cmath import collections.abc import contextlib +import warnings from typing import ( Any, Callable, Collection, + Dict, List, NoReturn, Optional, @@ -83,7 +85,8 @@ _DTYPE_PRECISIONS.update( def default_tolerances( - *inputs: Union[torch.Tensor, torch.dtype] + *inputs: Union[torch.Tensor, torch.dtype], + dtype_precisions: Optional[Dict[torch.dtype, Tuple[float, float]]] = None, ) -> Tuple[float, float]: """Returns the default absolute and relative testing tolerances for a set of inputs based on the dtype. @@ -102,7 +105,8 @@ def default_tolerances( raise TypeError( f"Expected a torch.Tensor or a torch.dtype, but got {type(input)} instead." ) - rtols, atols = zip(*[_DTYPE_PRECISIONS.get(dtype, (0.0, 0.0)) for dtype in dtypes]) + dtype_precisions = dtype_precisions or _DTYPE_PRECISIONS + rtols, atols = zip(*[dtype_precisions.get(dtype, (0.0, 0.0)) for dtype in dtypes]) return max(rtols), max(atols) @@ -1531,3 +1535,55 @@ def assert_close( check_stride=check_stride, msg=msg, ) + + +def assert_allclose( + actual: Any, + expected: Any, + rtol: Optional[float] = None, + atol: Optional[float] = None, + equal_nan: bool = True, + msg: str = "", +) -> None: + """ + .. warning:: + + :func:`torch.testing.assert_allclose` is deprecated since ``1.12`` and will be removed in a future release. + Please use :func:`torch.testing.assert_close` instead. You can find detailed upgrade instructions + `here `_. + """ + warnings.warn( + "`torch.testing.assert_allclose()` is deprecated since 1.12 and will be removed in a future release. " + "Please use `torch.testing.assert_close()` instead. " + "You can find detailed upgrade instructions in https://github.com/pytorch/pytorch/issues/61844.", + FutureWarning, + stacklevel=2, + ) + + if not isinstance(actual, torch.Tensor): + actual = torch.tensor(actual) + if not isinstance(expected, torch.Tensor): + expected = torch.tensor(expected, dtype=actual.dtype) + + if rtol is None and atol is None: + rtol, atol = default_tolerances( + actual, + expected, + dtype_precisions={ + torch.float16: (1e-3, 1e-3), + torch.float32: (1e-4, 1e-5), + torch.float64: (1e-5, 1e-8), + }, + ) + + torch.testing.assert_close( + actual, + expected, + rtol=rtol, + atol=atol, + equal_nan=equal_nan, + check_device=True, + check_dtype=False, + check_stride=False, + msg=msg or None, + ) diff --git a/torch/testing/_deprecated.py b/torch/testing/_deprecated.py deleted file mode 100644 index a9ef0c58cb9..00000000000 --- a/torch/testing/_deprecated.py +++ /dev/null @@ -1,85 +0,0 @@ -"""This module exists since the `torch.testing` exposed a lot of stuff that shouldn't have been public. Although this -was never documented anywhere, some other internal FB projects as well as downstream OSS projects might use this. Thus, -we don't internalize without warning, but still go through a deprecation cycle. -""" - -import functools -import warnings -from typing import Any, Callable, Dict, Optional, Tuple, Union - -import torch - - -__all__ = ["assert_allclose"] - - -def warn_deprecated( - instructions: Union[str, Callable[[str, Tuple[Any, ...], Dict[str, Any], Any], str]] -) -> Callable: - def outer_wrapper(fn: Callable) -> Callable: - name = fn.__name__ - head = f"torch.testing.{name}() is deprecated since 1.12 and will be removed in 1.14. " - - @functools.wraps(fn) - def inner_wrapper(*args: Any, **kwargs: Any) -> Any: - return_value = fn(*args, **kwargs) - tail = ( - instructions(name, args, kwargs, return_value) - if callable(instructions) - else instructions - ) - msg = (head + tail).strip() - warnings.warn(msg, FutureWarning) - return return_value - - return inner_wrapper - - return outer_wrapper - - -_DTYPE_PRECISIONS = { - torch.float16: (1e-3, 1e-3), - torch.float32: (1e-4, 1e-5), - torch.float64: (1e-5, 1e-8), -} - - -def _get_default_rtol_and_atol( - actual: torch.Tensor, expected: torch.Tensor -) -> Tuple[float, float]: - actual_rtol, actual_atol = _DTYPE_PRECISIONS.get(actual.dtype, (0.0, 0.0)) - expected_rtol, expected_atol = _DTYPE_PRECISIONS.get(expected.dtype, (0.0, 0.0)) - return max(actual_rtol, expected_rtol), max(actual_atol, expected_atol) - - -@warn_deprecated( - "Use torch.testing.assert_close() instead. " - "For detailed upgrade instructions see https://github.com/pytorch/pytorch/issues/61844." -) -def assert_allclose( - actual: Any, - expected: Any, - rtol: Optional[float] = None, - atol: Optional[float] = None, - equal_nan: bool = True, - msg: str = "", -) -> None: - if not isinstance(actual, torch.Tensor): - actual = torch.tensor(actual) - if not isinstance(expected, torch.Tensor): - expected = torch.tensor(expected, dtype=actual.dtype) - - if rtol is None and atol is None: - rtol, atol = _get_default_rtol_and_atol(actual, expected) - - torch.testing.assert_close( - actual, - expected, - rtol=rtol, - atol=atol, - equal_nan=equal_nan, - check_device=True, - check_dtype=False, - check_stride=False, - msg=msg or None, - )