diff --git a/test/allowlist_for_publicAPI.json b/test/allowlist_for_publicAPI.json index 5e9faf9fe9e..f4b4c621eb4 100644 --- a/test/allowlist_for_publicAPI.json +++ b/test/allowlist_for_publicAPI.json @@ -1032,7 +1032,6 @@ "List", "Number", "Sequence", - "Tuple", "Union" ], "torch.utils.benchmark.utils.compare": [ diff --git a/torch/_C/_cudnn.pyi b/torch/_C/_cudnn.pyi index 4eeefdd2e5d..cfea3f956f2 100644 --- a/torch/_C/_cudnn.pyi +++ b/torch/_C/_cudnn.pyi @@ -1,16 +1,13 @@ -from enum import Enum - -from torch.types import _bool +from enum import IntEnum # Defined in torch/csrc/cuda/shared/cudnn.cpp -is_cuda: _bool +is_cuda: bool def getRuntimeVersion() -> tuple[int, int, int]: ... def getCompileVersion() -> tuple[int, int, int]: ... def getVersionInt() -> int: ... -class RNNMode(int, Enum): - value: int +class RNNMode(IntEnum): rnn_relu = ... rnn_tanh = ... lstm = ... diff --git a/torch/types.py b/torch/types.py index ff0956f9517..ab6f4639f44 100644 --- a/torch/types.py +++ b/torch/types.py @@ -1,5 +1,3 @@ -# mypy: allow-untyped-defs - # In some cases, these basic types are shadowed by corresponding # top-level values. The underscore variants let us refer to these # types. See https://github.com/python/mypy/issues/4146 for why these @@ -15,7 +13,7 @@ from builtins import ( # noqa: F401 ) from collections.abc import Sequence from typing import Any, IO, TYPE_CHECKING, Union -from typing_extensions import TypeAlias +from typing_extensions import Self, TypeAlias # `as` imports have better static analysis support than assignment `ExposedType: TypeAlias = HiddenType` from torch import ( # noqa: F401 @@ -59,7 +57,7 @@ FloatLikeType: TypeAlias = Union[float, SymFloat] # bool or SymBool BoolLikeType: TypeAlias = Union[bool, SymBool] -py_sym_types = (SymInt, SymFloat, SymBool) +py_sym_types = (SymInt, SymFloat, SymBool) # left un-annotated intentionally PySymType: TypeAlias = Union[SymInt, SymFloat, SymBool] # Meta-type for "numeric" things; matches our docs @@ -83,10 +81,10 @@ class Storage: dtype: _dtype _torch_load_uninitialized: bool - def __deepcopy__(self, memo: dict[int, Any]) -> "Storage": + def __deepcopy__(self, memo: dict[int, Any]) -> Self: raise NotImplementedError - def _new_shared(self, size: int) -> "Storage": + def _new_shared(self, size: int) -> Self: raise NotImplementedError def _write_file( @@ -104,13 +102,13 @@ class Storage: def is_shared(self) -> bool: raise NotImplementedError - def share_memory_(self) -> "Storage": + def share_memory_(self) -> Self: raise NotImplementedError def nbytes(self) -> int: raise NotImplementedError - def cpu(self) -> "Storage": + def cpu(self) -> Self: raise NotImplementedError def data_ptr(self) -> int: @@ -121,12 +119,12 @@ class Storage: filename: str, shared: bool = False, nbytes: int = 0, - ) -> "Storage": + ) -> Self: raise NotImplementedError def _new_with_file( self, f: Any, element_size: int, - ) -> "Storage": + ) -> Self: raise NotImplementedError