mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
document torch.testing.assert_allclose (#89526)
After our failed attempt to remove `assert_allclose` in #87974, we decided to add it to the documentation after all. Although we drop the expected removal date, the function continues to be deprecated in favor of `assert_close`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/89526 Approved by: https://github.com/mruberry
This commit is contained in:
parent
4baa78bb1f
commit
d72cd4c4e5
4 changed files with 60 additions and 89 deletions
|
|
@ -6,3 +6,4 @@ torch.testing
|
|||
|
||||
.. autofunction:: assert_close
|
||||
.. autofunction:: make_tensor
|
||||
.. autofunction:: assert_allclose
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
from torch._C import FileCheck as FileCheck
|
||||
from ._comparison import assert_close as assert_close
|
||||
from ._comparison import assert_allclose, assert_close as assert_close
|
||||
from ._creation import make_tensor as make_tensor
|
||||
from ._deprecated import * # noqa: F403
|
||||
|
|
|
|||
|
|
@ -2,10 +2,12 @@ import abc
|
|||
import cmath
|
||||
import collections.abc
|
||||
import contextlib
|
||||
import warnings
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Collection,
|
||||
Dict,
|
||||
List,
|
||||
NoReturn,
|
||||
Optional,
|
||||
|
|
@ -83,7 +85,8 @@ _DTYPE_PRECISIONS.update(
|
|||
|
||||
|
||||
def default_tolerances(
|
||||
*inputs: Union[torch.Tensor, torch.dtype]
|
||||
*inputs: Union[torch.Tensor, torch.dtype],
|
||||
dtype_precisions: Optional[Dict[torch.dtype, Tuple[float, float]]] = None,
|
||||
) -> Tuple[float, float]:
|
||||
"""Returns the default absolute and relative testing tolerances for a set of inputs based on the dtype.
|
||||
|
||||
|
|
@ -102,7 +105,8 @@ def default_tolerances(
|
|||
raise TypeError(
|
||||
f"Expected a torch.Tensor or a torch.dtype, but got {type(input)} instead."
|
||||
)
|
||||
rtols, atols = zip(*[_DTYPE_PRECISIONS.get(dtype, (0.0, 0.0)) for dtype in dtypes])
|
||||
dtype_precisions = dtype_precisions or _DTYPE_PRECISIONS
|
||||
rtols, atols = zip(*[dtype_precisions.get(dtype, (0.0, 0.0)) for dtype in dtypes])
|
||||
return max(rtols), max(atols)
|
||||
|
||||
|
||||
|
|
@ -1531,3 +1535,55 @@ def assert_close(
|
|||
check_stride=check_stride,
|
||||
msg=msg,
|
||||
)
|
||||
|
||||
|
||||
def assert_allclose(
|
||||
actual: Any,
|
||||
expected: Any,
|
||||
rtol: Optional[float] = None,
|
||||
atol: Optional[float] = None,
|
||||
equal_nan: bool = True,
|
||||
msg: str = "",
|
||||
) -> None:
|
||||
"""
|
||||
.. warning::
|
||||
|
||||
:func:`torch.testing.assert_allclose` is deprecated since ``1.12`` and will be removed in a future release.
|
||||
Please use :func:`torch.testing.assert_close` instead. You can find detailed upgrade instructions
|
||||
`here <https://github.com/pytorch/pytorch/issues/61844>`_.
|
||||
"""
|
||||
warnings.warn(
|
||||
"`torch.testing.assert_allclose()` is deprecated since 1.12 and will be removed in a future release. "
|
||||
"Please use `torch.testing.assert_close()` instead. "
|
||||
"You can find detailed upgrade instructions in https://github.com/pytorch/pytorch/issues/61844.",
|
||||
FutureWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
if not isinstance(actual, torch.Tensor):
|
||||
actual = torch.tensor(actual)
|
||||
if not isinstance(expected, torch.Tensor):
|
||||
expected = torch.tensor(expected, dtype=actual.dtype)
|
||||
|
||||
if rtol is None and atol is None:
|
||||
rtol, atol = default_tolerances(
|
||||
actual,
|
||||
expected,
|
||||
dtype_precisions={
|
||||
torch.float16: (1e-3, 1e-3),
|
||||
torch.float32: (1e-4, 1e-5),
|
||||
torch.float64: (1e-5, 1e-8),
|
||||
},
|
||||
)
|
||||
|
||||
torch.testing.assert_close(
|
||||
actual,
|
||||
expected,
|
||||
rtol=rtol,
|
||||
atol=atol,
|
||||
equal_nan=equal_nan,
|
||||
check_device=True,
|
||||
check_dtype=False,
|
||||
check_stride=False,
|
||||
msg=msg or None,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,85 +0,0 @@
|
|||
"""This module exists since the `torch.testing` exposed a lot of stuff that shouldn't have been public. Although this
|
||||
was never documented anywhere, some other internal FB projects as well as downstream OSS projects might use this. Thus,
|
||||
we don't internalize without warning, but still go through a deprecation cycle.
|
||||
"""
|
||||
|
||||
import functools
|
||||
import warnings
|
||||
from typing import Any, Callable, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
__all__ = ["assert_allclose"]
|
||||
|
||||
|
||||
def warn_deprecated(
|
||||
instructions: Union[str, Callable[[str, Tuple[Any, ...], Dict[str, Any], Any], str]]
|
||||
) -> Callable:
|
||||
def outer_wrapper(fn: Callable) -> Callable:
|
||||
name = fn.__name__
|
||||
head = f"torch.testing.{name}() is deprecated since 1.12 and will be removed in 1.14. "
|
||||
|
||||
@functools.wraps(fn)
|
||||
def inner_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
return_value = fn(*args, **kwargs)
|
||||
tail = (
|
||||
instructions(name, args, kwargs, return_value)
|
||||
if callable(instructions)
|
||||
else instructions
|
||||
)
|
||||
msg = (head + tail).strip()
|
||||
warnings.warn(msg, FutureWarning)
|
||||
return return_value
|
||||
|
||||
return inner_wrapper
|
||||
|
||||
return outer_wrapper
|
||||
|
||||
|
||||
_DTYPE_PRECISIONS = {
|
||||
torch.float16: (1e-3, 1e-3),
|
||||
torch.float32: (1e-4, 1e-5),
|
||||
torch.float64: (1e-5, 1e-8),
|
||||
}
|
||||
|
||||
|
||||
def _get_default_rtol_and_atol(
|
||||
actual: torch.Tensor, expected: torch.Tensor
|
||||
) -> Tuple[float, float]:
|
||||
actual_rtol, actual_atol = _DTYPE_PRECISIONS.get(actual.dtype, (0.0, 0.0))
|
||||
expected_rtol, expected_atol = _DTYPE_PRECISIONS.get(expected.dtype, (0.0, 0.0))
|
||||
return max(actual_rtol, expected_rtol), max(actual_atol, expected_atol)
|
||||
|
||||
|
||||
@warn_deprecated(
|
||||
"Use torch.testing.assert_close() instead. "
|
||||
"For detailed upgrade instructions see https://github.com/pytorch/pytorch/issues/61844."
|
||||
)
|
||||
def assert_allclose(
|
||||
actual: Any,
|
||||
expected: Any,
|
||||
rtol: Optional[float] = None,
|
||||
atol: Optional[float] = None,
|
||||
equal_nan: bool = True,
|
||||
msg: str = "",
|
||||
) -> None:
|
||||
if not isinstance(actual, torch.Tensor):
|
||||
actual = torch.tensor(actual)
|
||||
if not isinstance(expected, torch.Tensor):
|
||||
expected = torch.tensor(expected, dtype=actual.dtype)
|
||||
|
||||
if rtol is None and atol is None:
|
||||
rtol, atol = _get_default_rtol_and_atol(actual, expected)
|
||||
|
||||
torch.testing.assert_close(
|
||||
actual,
|
||||
expected,
|
||||
rtol=rtol,
|
||||
atol=atol,
|
||||
equal_nan=equal_nan,
|
||||
check_device=True,
|
||||
check_dtype=False,
|
||||
check_stride=False,
|
||||
msg=msg or None,
|
||||
)
|
||||
Loading…
Reference in a new issue