diff --git a/torch/storage.py b/torch/storage.py index 66e9d92f96c..59023a70106 100644 --- a/torch/storage.py +++ b/torch/storage.py @@ -1,4 +1,7 @@ # mypy: allow-untyped-defs + +from __future__ import annotations + import collections import copy import functools @@ -6,6 +9,7 @@ import io import threading import warnings from typing import Any, cast, Dict as _Dict, Optional as _Optional, Type, TypeVar, Union +from typing_extensions import Self import torch from torch._utils import _to, _type @@ -48,7 +52,7 @@ class _StorageBase: def copy_(self, source: T, non_blocking: _Optional[_bool] = None) -> T: raise NotImplementedError - def new(self) -> T: # type: ignore[type-var] + def new(self) -> Union[_StorageBase, TypedStorage]: raise NotImplementedError def nbytes(self) -> _int: @@ -57,10 +61,14 @@ class _StorageBase: def size(self) -> _int: return self.nbytes() - def type(self, dtype: _Optional[str] = None, non_blocking: _bool = False) -> T: # type: ignore[type-var] + def type( + self, dtype: _Optional[str] = None, non_blocking: _bool = False + ) -> Union[_StorageBase, TypedStorage]: return _type(self, dtype, non_blocking) - def cuda(self, device=None, non_blocking=False) -> T: # type: ignore[type-var, misc] # noqa: E704 + def cuda( + self, device=None, non_blocking=False + ) -> Union[_StorageBase, TypedStorage]: """Returns a copy of this object in CUDA memory. If this object is already in CUDA memory and on the correct device, then @@ -75,7 +83,7 @@ class _StorageBase: device2 = torch.device("cuda", device) if device else torch.device("cuda") return self.to(device=device2, non_blocking=non_blocking) - def hpu(self, device=None, non_blocking=False) -> T: # type: ignore[type-var, misc] # noqa: E704 + def hpu(self, device=None, non_blocking=False) -> Union[_StorageBase, TypedStorage]: """Returns a copy of this object in HPU memory. If this object is already in HPU memory and on the correct device, then @@ -141,7 +149,7 @@ class _StorageBase: def _new_with_weak_ptr(cls: Type[T], *args, **kwargs) -> T: raise NotImplementedError - def _shared_decref(self) -> T: # type: ignore[type-var] + def _shared_decref(self) -> Union[_StorageBase, TypedStorage]: raise NotImplementedError def _write_file(self, *args, **kwargs): @@ -150,7 +158,7 @@ class _StorageBase: def resize_(self, size: _int): raise NotImplementedError - def _weak_ref(self, *args, **kwargs) -> T: # type: ignore[type-var] + def _weak_ref(self, *args, **kwargs) -> Union[_StorageBase, TypedStorage]: raise NotImplementedError def _set_from_file(self, *args, **kwargs): @@ -185,11 +193,11 @@ class _StorageBase: raise NotImplementedError @classmethod - def from_file(cls, filename, shared, nbytes) -> T: # type: ignore[type-var] + def from_file(cls, filename, shared, nbytes) -> Union[_StorageBase, TypedStorage]: raise NotImplementedError @classmethod - def _expired(cls, *args, **kwargs) -> T: # type: ignore[type-var] + def _expired(cls, *args, **kwargs) -> Union[_StorageBase, TypedStorage]: raise NotImplementedError def _byteswap(self, *args, **kwargs): @@ -260,7 +268,9 @@ class _StorageBase: storage = storage.clone() return storage - def to(self, *, device: torch.device, non_blocking: _bool = False) -> T: # type: ignore[type-var, misc] # noqa: E704 + def to( + self, *, device: torch.device, non_blocking: _bool = False + ) -> Union[_StorageBase, TypedStorage]: return _to(self, device, non_blocking) def double(self): @@ -852,12 +862,15 @@ class TypedStorage: _warn_typed_storage_removal() return self._untyped_storage - def _new_wrapped_storage(self, untyped_storage): + def _new_wrapped_storage(self, untyped_storage) -> Self: assert type(untyped_storage) == torch.UntypedStorage if type(self) == TypedStorage: - return TypedStorage( - wrap_storage=untyped_storage, dtype=self.dtype, _internal=True + return cast( + Self, + TypedStorage( + wrap_storage=untyped_storage, dtype=self.dtype, _internal=True + ), ) else: return type(self)(wrap_storage=untyped_storage) @@ -982,9 +995,9 @@ class TypedStorage: def copy_(self, source: T, non_blocking: _Optional[bool] = None): _warn_typed_storage_removal() if isinstance(source, TypedStorage): - self._untyped_storage.copy_(source._untyped_storage, non_blocking) # type: ignore[arg-type] + self._untyped_storage.copy_(source._untyped_storage, non_blocking) else: - self._untyped_storage.copy_(source, non_blocking) # type: ignore[arg-type] + self._untyped_storage.copy_(source, non_blocking) return self def nbytes(self): @@ -999,7 +1012,7 @@ class TypedStorage: self, dtype: _Optional[str] = None, non_blocking: bool = False, - ) -> Union[T, str]: + ) -> Union[_StorageBase, TypedStorage, str]: _warn_typed_storage_removal() if dtype is None: legacy_class = self._get_legacy_storage_class() @@ -1012,7 +1025,7 @@ class TypedStorage: else: return self._untyped_storage.type(dtype, non_blocking) - def cuda(self, device=None, non_blocking=False) -> T: # type: ignore[misc,type-var] + def cuda(self, device=None, non_blocking=False) -> Self: _warn_typed_storage_removal() if self.dtype in [ torch.quint8, @@ -1022,12 +1035,10 @@ class TypedStorage: torch.qint8, ]: raise RuntimeError("Cannot create CUDA storage with quantized dtype") - cuda_storage: torch.UntypedStorage = self._untyped_storage.cuda( - device, non_blocking - ) + cuda_storage = self._untyped_storage.cuda(device, non_blocking) return self._new_wrapped_storage(cuda_storage) - def hpu(self, device=None, non_blocking=False) -> T: # type: ignore[misc,type-var] + def hpu(self, device=None, non_blocking=False) -> Self: _warn_typed_storage_removal() if self.dtype in [ torch.quint8, @@ -1037,12 +1048,10 @@ class TypedStorage: torch.qint8, ]: raise RuntimeError("Cannot create HPU storage with quantized dtype") - hpu_storage: torch.UntypedStorage = self._untyped_storage.hpu( - device, non_blocking - ) + hpu_storage = self._untyped_storage.hpu(device, non_blocking) return self._new_wrapped_storage(hpu_storage) - def to(self, *, device: torch.device, non_blocking: bool = False) -> T: # type: ignore[type-var, misc] + def to(self, *, device: torch.device, non_blocking: bool = False) -> Self: _warn_typed_storage_removal() if self.dtype in [ torch.quint8, @@ -1054,9 +1063,7 @@ class TypedStorage: raise RuntimeError( f"Cannot create {device.type.upper()} storage with quantized dtype" ) - to_storage: torch.UntypedStorage = self._untyped_storage.to( - device=device, non_blocking=non_blocking - ) + to_storage = self._untyped_storage.to(device=device, non_blocking=non_blocking) return self._new_wrapped_storage(to_storage) def element_size(self): @@ -1385,7 +1392,7 @@ class TypedStorage: _warn_typed_storage_removal() if cls == TypedStorage: raise RuntimeError("from_file can only be called on derived classes") - untyped_storage: UntypedStorage = UntypedStorage.from_file( + untyped_storage = UntypedStorage.from_file( filename, shared, size * torch._utils._element_size(cls.dtype) ) storage = cls(wrap_storage=untyped_storage)