document torch.testing.assert_allclose (#89526)

After our failed attempt to remove `assert_allclose` in #87974, we decided to add it to the documentation after all. Although we dropped the expected removal date, the function remains deprecated in favor of `assert_close`.
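For downstream code the migration is usually a one-line change. A minimal before/after sketch (the tensors and keyword arguments below are illustrative only and not part of this commit):

```python
import torch
import torch.testing

actual = torch.tensor([1.0, 2.0, float("nan")])
expected = torch.tensor([1.0, 2.0, float("nan")], dtype=torch.float64)

# Deprecated spelling: still works, but emits a FutureWarning.
torch.testing.assert_allclose(actual, expected)

# Non-deprecated replacement. assert_allclose ignored dtype mismatches and
# treated NaNs as equal by default, so those behaviors are requested
# explicitly here. The default tolerances of the two functions also differ,
# so rtol/atol may need to be pinned if a test relies on the legacy defaults.
torch.testing.assert_close(actual, expected, check_dtype=False, equal_nan=True)
```

See https://github.com/pytorch/pytorch/issues/61844 for the full upgrade instructions.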

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89526
Approved by: https://github.com/mruberry
Philip Meier 2022-12-01 09:28:02 +01:00 committed by PyTorch MergeBot
parent 4baa78bb1f
commit d72cd4c4e5
4 changed files with 60 additions and 89 deletions


@@ -6,3 +6,4 @@ torch.testing
.. autofunction:: assert_close
.. autofunction:: make_tensor
.. autofunction:: assert_allclose


@@ -1,4 +1,3 @@
from torch._C import FileCheck as FileCheck
from ._comparison import assert_close as assert_close
from ._comparison import assert_allclose, assert_close as assert_close
from ._creation import make_tensor as make_tensor
from ._deprecated import * # noqa: F403


@@ -2,10 +2,12 @@ import abc
import cmath
import collections.abc
import contextlib
import warnings
from typing import (
    Any,
    Callable,
    Collection,
    Dict,
    List,
    NoReturn,
    Optional,
@@ -83,7 +85,8 @@ _DTYPE_PRECISIONS.update(
def default_tolerances(
    *inputs: Union[torch.Tensor, torch.dtype]
    *inputs: Union[torch.Tensor, torch.dtype],
    dtype_precisions: Optional[Dict[torch.dtype, Tuple[float, float]]] = None,
) -> Tuple[float, float]:
    """Returns the default absolute and relative testing tolerances for a set of inputs based on the dtype.
@@ -102,7 +105,8 @@ def default_tolerances(
            raise TypeError(
                f"Expected a torch.Tensor or a torch.dtype, but got {type(input)} instead."
            )
    rtols, atols = zip(*[_DTYPE_PRECISIONS.get(dtype, (0.0, 0.0)) for dtype in dtypes])
    dtype_precisions = dtype_precisions or _DTYPE_PRECISIONS
    rtols, atols = zip(*[dtype_precisions.get(dtype, (0.0, 0.0)) for dtype in dtypes])
    return max(rtols), max(atols)
@@ -1531,3 +1535,55 @@ def assert_close(
        check_stride=check_stride,
        msg=msg,
    )


def assert_allclose(
    actual: Any,
    expected: Any,
    rtol: Optional[float] = None,
    atol: Optional[float] = None,
    equal_nan: bool = True,
    msg: str = "",
) -> None:
    """
    .. warning::

        :func:`torch.testing.assert_allclose` is deprecated since ``1.12`` and will be removed in a future release.
        Please use :func:`torch.testing.assert_close` instead. You can find detailed upgrade instructions
        `here <https://github.com/pytorch/pytorch/issues/61844>`_.
    """
    warnings.warn(
        "`torch.testing.assert_allclose()` is deprecated since 1.12 and will be removed in a future release. "
        "Please use `torch.testing.assert_close()` instead. "
        "You can find detailed upgrade instructions in https://github.com/pytorch/pytorch/issues/61844.",
        FutureWarning,
        stacklevel=2,
    )

    if not isinstance(actual, torch.Tensor):
        actual = torch.tensor(actual)
    if not isinstance(expected, torch.Tensor):
        expected = torch.tensor(expected, dtype=actual.dtype)

    if rtol is None and atol is None:
        rtol, atol = default_tolerances(
            actual,
            expected,
            dtype_precisions={
                torch.float16: (1e-3, 1e-3),
                torch.float32: (1e-4, 1e-5),
                torch.float64: (1e-5, 1e-8),
            },
        )

    torch.testing.assert_close(
        actual,
        expected,
        rtol=rtol,
        atol=atol,
        equal_nan=equal_nan,
        check_device=True,
        check_dtype=False,
        check_stride=False,
        msg=msg or None,
    )


@@ -1,85 +0,0 @@
"""This module exists since the `torch.testing` exposed a lot of stuff that shouldn't have been public. Although this
was never documented anywhere, some other internal FB projects as well as downstream OSS projects might use this. Thus,
we don't internalize without warning, but still go through a deprecation cycle.
"""

import functools
import warnings
from typing import Any, Callable, Dict, Optional, Tuple, Union

import torch

__all__ = ["assert_allclose"]


def warn_deprecated(
    instructions: Union[str, Callable[[str, Tuple[Any, ...], Dict[str, Any], Any], str]]
) -> Callable:
    def outer_wrapper(fn: Callable) -> Callable:
        name = fn.__name__
        head = f"torch.testing.{name}() is deprecated since 1.12 and will be removed in 1.14. "

        @functools.wraps(fn)
        def inner_wrapper(*args: Any, **kwargs: Any) -> Any:
            return_value = fn(*args, **kwargs)
            tail = (
                instructions(name, args, kwargs, return_value)
                if callable(instructions)
                else instructions
            )
            msg = (head + tail).strip()
            warnings.warn(msg, FutureWarning)
            return return_value

        return inner_wrapper

    return outer_wrapper


_DTYPE_PRECISIONS = {
    torch.float16: (1e-3, 1e-3),
    torch.float32: (1e-4, 1e-5),
    torch.float64: (1e-5, 1e-8),
}


def _get_default_rtol_and_atol(
    actual: torch.Tensor, expected: torch.Tensor
) -> Tuple[float, float]:
    actual_rtol, actual_atol = _DTYPE_PRECISIONS.get(actual.dtype, (0.0, 0.0))
    expected_rtol, expected_atol = _DTYPE_PRECISIONS.get(expected.dtype, (0.0, 0.0))
    return max(actual_rtol, expected_rtol), max(actual_atol, expected_atol)


@warn_deprecated(
    "Use torch.testing.assert_close() instead. "
    "For detailed upgrade instructions see https://github.com/pytorch/pytorch/issues/61844."
)
def assert_allclose(
    actual: Any,
    expected: Any,
    rtol: Optional[float] = None,
    atol: Optional[float] = None,
    equal_nan: bool = True,
    msg: str = "",
) -> None:
    if not isinstance(actual, torch.Tensor):
        actual = torch.tensor(actual)
    if not isinstance(expected, torch.Tensor):
        expected = torch.tensor(expected, dtype=actual.dtype)

    if rtol is None and atol is None:
        rtol, atol = _get_default_rtol_and_atol(actual, expected)

    torch.testing.assert_close(
        actual,
        expected,
        rtol=rtol,
        atol=atol,
        equal_nan=equal_nan,
        check_device=True,
        check_dtype=False,
        check_stride=False,
        msg=msg or None,
    )