diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py
index 6e106468e4d..0b57d7f5771 100644
--- a/tools/pyi/gen_pyi.py
+++ b/tools/pyi/gen_pyi.py
@@ -1204,7 +1204,7 @@ def gen_pyi(
         ],
         "set_": [
             "def set_(self, storage: Union[Storage, TypedStorage, UntypedStorage], "
-            "offset: _int, size: _size, stride: _size) -> Tensor: ...",
+            "offset: _int, size: _symsize, stride: _symsize) -> Tensor: ...",
             "def set_(self, storage: Union[Storage, TypedStorage, UntypedStorage]) -> Tensor: ...",
         ],
         "split": [
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index e7508cc0fdb..db2e65e0622 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -56,6 +56,7 @@ from torch.types import (
     _qscheme,
     _size,
     _str,
+    _symsize,
 )
 
 from torch.utils._python_dispatch import TorchDispatchMode
@@ -1661,6 +1662,18 @@ class _SetExcludeDispatchKeyGuard:
     def __enter__(self): ...
     def __exit__(self, exc_type, exc_value, traceback): ...
 
+# Defined in torch/csrc/utils/schema_info.h
+
+class _SchemaInfo:
+    def __init__(self, schema: _int) -> None: ...
+
+    @overload
+    def is_mutable(self) -> _bool: ...
+    @overload
+    def is_mutable(self, name: str) -> _bool: ...
+
+    def has_argument(self, name: str) -> _bool: ...
+
 # Defined in torch/csrc/utils/init.cpp
 class BenchmarkConfig:
     num_calling_threads: _int
diff --git a/torch/__init__.py b/torch/__init__.py
index 8830d79412b..6e81b1ef1de 100644
--- a/torch/__init__.py
+++ b/torch/__init__.py
@@ -36,6 +36,9 @@ from typing import (
 )
 from typing_extensions import ParamSpec as _ParamSpec, TypeGuard as _TypeGuard
 
+if TYPE_CHECKING:
+    from .types import IntLikeType
+
 
 # multipy/deploy is setting this import before importing torch, this is the most
 # reliable way we have to detect if we're running within deploy.
@@ -471,6 +474,9 @@ class SymInt:
     def __add__(self, other) -> "SymInt":
         raise TypeError("type stub not overridden")
 
+    def __mod__(self, other: "IntLikeType") -> "SymInt":
+        raise TypeError("type stub not overridden")
+
     def __mul__(self, other) -> "SymInt":
         raise TypeError("type stub not overridden")
 
@@ -504,6 +510,9 @@ class SymInt:
     def __neg__(self):
         raise TypeError("type stub not overridden")
 
+    def __sub__(self, other: "IntLikeType") -> "SymInt":
+        raise TypeError("type stub not overridden")
+
     def __repr__(self):
         return self.node._graph_repr()
 
diff --git a/torch/_dynamo/decorators.py b/torch/_dynamo/decorators.py
index 79bbb493865..dec73a6e6d1 100644
--- a/torch/_dynamo/decorators.py
+++ b/torch/_dynamo/decorators.py
@@ -165,6 +165,7 @@ def _apply_func_to_inner_tensors_of_same_dim(func, t, *args, **kwargs):
     assert is_traceable_wrapper_subclass(t)
 
     attrs, ctx = t.__tensor_flatten__()
+    assert isinstance(t, torch.Tensor)
     for attr in attrs:
         inner = getattr(t, attr)
         if inner.dim() == t.dim():
diff --git a/torch/_export/non_strict_utils.py b/torch/_export/non_strict_utils.py
index bd7ffe0f195..5c861dac25c 100644
--- a/torch/_export/non_strict_utils.py
+++ b/torch/_export/non_strict_utils.py
@@ -83,6 +83,7 @@ def fakify(
         constraint_sizes=[None] * n_dims,
     )
     t_id = id(t)
+    assert mode.shape_env is not None
     if t_id in t_constraints:
         for i, constraint in t_constraints[t_id].items():
             symbolic_context.constraint_sizes[i] = constraint.constraint_range
@@ -256,6 +257,7 @@ def produce_guards_and_solve_constraints(
         _disable_forced_specializations: if True, avoids forced specializations
     """
     shape_env = fake_mode.shape_env
+    assert shape_env is not None
     assert shape_env.tracked_fakes is not None
     placeholders = [tf.fake for tf in shape_env.tracked_fakes]
 
@@ -322,6 +324,7 @@ def make_constraints(
     """
     shape_env = fake_mode.shape_env
+    assert shape_env is not None
     inline_constraints = gm.meta.get("inline_constraints", [])
     range_constraints = {
         symbol: inline_constraints[symbol] for symbol in inline_constraints
diff --git a/torch/_functorch/_aot_autograd/runtime_wrappers.py b/torch/_functorch/_aot_autograd/runtime_wrappers.py
index 34e83c0b7fc..f9ad2885d29 100644
--- a/torch/_functorch/_aot_autograd/runtime_wrappers.py
+++ b/torch/_functorch/_aot_autograd/runtime_wrappers.py
@@ -12,7 +12,7 @@ import pprint
 from contextlib import nullcontext
 from dataclasses import dataclass, field
 from functools import wraps
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.utils.dlpack
@@ -1450,7 +1450,7 @@ Expected metadata: {str(expected_tangent_metadata)}
 
 Runtime metadata: {str(runtime_tangent_metadata)}
 
-shape: {str(x.shape)}
+shape: {str(cast(torch.Tensor, x).shape)}
 To fix this, your tensor subclass must implement the dunder method __force_to_same_metadata__.
 """
                 )
@@ -1830,14 +1830,16 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
                 )
                 assert CompiledFunction.metadata.traced_tangent_metas is not None
                 all_args = [
-                    AOTDispatchAutograd.coerce_runtime_tangent(
-                        t,
-                        CompiledFunction.metadata.traced_tangent_metas[
-                            i - tangents_start_idx
-                        ],
+                    (
+                        AOTDispatchAutograd.coerce_runtime_tangent(
+                            t,
+                            CompiledFunction.metadata.traced_tangent_metas[
+                                i - tangents_start_idx
+                            ],
+                        )
+                        if tangents_start_idx <= i < tangents_end_idx
+                        else t
                     )
-                    if tangents_start_idx <= i < tangents_end_idx
-                    else t
                     for i, t in enumerate(all_args)
                 ]
                 all_args = unwrap_tensor_subclasses(
@@ -1849,9 +1851,11 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
             # Make the tangents contiguous. Note that we must do this after subclass desugaring
             # because inputs to inductor have to be contiguous
             all_args = [
-                AOTDispatchAutograd._force_contiguous(t)
-                if (tangents_start_idx <= i < tangents_end_idx)
-                else t
+                (
+                    AOTDispatchAutograd._force_contiguous(t)
+                    if (tangents_start_idx <= i < tangents_end_idx)
+                    else t
+                )
                 for i, t in enumerate(all_args)
             ]
 
diff --git a/torch/_functorch/_aot_autograd/subclass_utils.py b/torch/_functorch/_aot_autograd/subclass_utils.py
index 4dfabaa4006..aa7adcc221f 100644
--- a/torch/_functorch/_aot_autograd/subclass_utils.py
+++ b/torch/_functorch/_aot_autograd/subclass_utils.py
@@ -5,6 +5,7 @@ AOTAutograd's responsibility is to trace through all pytorch capabilities that l
 and this includes tensor subclasses that implement __torch_dispatch__.
 """
 
+import typing
 from typing import Any, List, Optional, Tuple, Union
 
 import torch.utils._pytree as pytree
@@ -115,7 +116,7 @@ def unwrap_tensor_subclasses(wrapped_args, *, is_joint_structure: bool):
         xs_inner = []
         for x in xs:
             if is_traceable_wrapper_subclass(x):
-                xs_inner.extend(get_plain_tensors(x))
+                xs_inner.extend(get_plain_tensors(typing.cast(Tensor, x)))
             else:
                 xs_inner.append(x)
         return xs_inner
diff --git a/torch/_functorch/config.py b/torch/_functorch/config.py
index 8135c9ebc55..0e3bb491b18 100644
--- a/torch/_functorch/config.py
+++ b/torch/_functorch/config.py
@@ -16,7 +16,7 @@ from typing import TYPE_CHECKING
 functionalize_rng_ops = False
 
 # can be useful for debugging if we are incorrectly creating meta fake tensors
-fake_tensor_allow_meta = os.environ.get("FAKE_ALLOW_META", True)
+fake_tensor_allow_meta = os.environ.get("FAKE_ALLOW_META", "1") != "0"
 
 # Enables optional asserts in hotpath code to check for errors.  If
 # you are seeing weird accuracy problems, try turning this on.
@@ -24,7 +24,7 @@ fake_tensor_allow_meta = os.environ.get("FAKE_ALLOW_META", True)
 # but it is on by default for aot_eager.
 debug_assert = False
 
-debug_partitioner = os.environ.get("AOT_PARTITIONER_DEBUG", False)
+debug_partitioner = os.environ.get("AOT_PARTITIONER_DEBUG", "0") != "0"
 
 # Today, if you are in a situation where there is "false aliasing"
 # (e.g. you have a bunch of model parameters that all alias the same underlying buffer),
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index 714f13c30ea..ab07ad4c28c 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -1,36 +1,41 @@
-# mypy: allow-untyped-defs
+from __future__ import annotations
+
 import contextlib
 import functools
 import logging
 import math
 import os
 import traceback
+import typing
 import weakref
+
 from collections import defaultdict
 from dataclasses import dataclass
 from typing import (
-    Any,
     Callable,
     cast,
     Dict,
+    Generator,
     List,
+    Literal,
+    Mapping,
     Optional,
+    Sequence,
+    Set,
     Tuple,
     Type,
     TYPE_CHECKING,
     TypeVar,
     Union,
 )
-from typing_extensions import TypeGuard
+from typing_extensions import Self, TypeGuard
 from weakref import ReferenceType
 
 import torch
 import torch._custom_op
-import torch._logging
-from torch._C._functorch import is_functorch_wrapped_tensor, is_legacy_batchedtensor
-from torch._guards import Source
-from torch._ops import OpOverload
+from torch import SymBool, SymFloat, SymInt, Tensor
+from torch._C._functorch import is_functorch_wrapped_tensor, is_legacy_batchedtensor
 from torch._prims_common import suggest_memory_format
 from torch._subclasses.meta_utils import (
     assert_eq,
@@ -40,6 +45,7 @@ from torch._subclasses.meta_utils import (
     MetaConverter,
 )
 from torch._utils import render_call
+from torch.fx.immutable_collections import immutable_dict
 from torch.fx.operator_schemas import normalize_function
 from torch.multiprocessing.reductions import StorageWeakRef
 from torch.overrides import TorchFunctionMode
@@ -48,36 +54,17 @@ from torch.utils._python_dispatch import (
     is_traceable_wrapper_subclass,
     TorchDispatchMode,
 )
-from torch.utils._pytree import PyTree, tree_map, tree_map_
+from torch.utils._pytree import PyTree, tree_map, tree_map_, TreeSpec
 from torch.utils._stats import count
 from torch.utils._traceback import CapturedTraceback
 
 if TYPE_CHECKING:
-    from torch.fx.experimental.symbolic_shapes import ShapeEnv
-    from torch.types import _bool
+    from types import TracebackType
 
-
-class _Unassigned:
-    pass
-
-
-def _is_plain_tensor(t):
-    return (
-        type(t) is torch.Tensor
-        and t.layout == torch.strided
-        and not (
-            t.is_sparse
-            or t.is_nested
-            or is_functorch_wrapped_tensor(t)
-            or is_legacy_batchedtensor(t)
-            or torch._is_functional_tensor(t)
-        )
-    )
-
-
-_UNASSIGNED = _Unassigned()
-
-DimList = List
+    from torch._guards import Source
+    from torch._ops import OpOverload
+    from torch.fx.experimental.symbolic_shapes import ShapeEnv, SymbolicContext
+    from torch.types import IntLikeType
 
 log = logging.getLogger(__name__)
 
@@ -91,9 +78,17 @@ except ValueError as e:
     else:
         raise e
 
+
+class _Unassigned:
+    pass
+
+
+_UNASSIGNED = _Unassigned()
+
+DimList = List
+
 pytree = torch.utils._pytree
 
 T = TypeVar("T")
-TensorWeakRef = Any
 
 aten = torch._ops.ops.aten
 
@@ -107,11 +102,11 @@ RECURSION_COUNT = 0
 # if you don't want to increase indentation which is
 # what a context manager would do.
 class IncrementRecursionCount:
-    def __init__(self):
+    def __init__(self) -> None:
         global RECURSION_COUNT
         RECURSION_COUNT += 1
 
-    def __del__(self):
+    def __del__(self) -> None:
         global RECURSION_COUNT
         RECURSION_COUNT -= 1
 
@@ -136,12 +131,12 @@ class UnsupportedOperatorException(RuntimeError):
     func: OpOverload
 
 
-def ordered_set(*items):
+def ordered_set(*items: T) -> Dict[T, Literal[True]]:
     return dict.fromkeys(items, True)
 
 
 @contextlib.contextmanager
-def unset_fake_temporarily():
+def unset_fake_temporarily() -> Generator[Optional[TorchDispatchMode], None, None]:
     old = torch._C._unset_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE)
     try:
         yield old
@@ -150,7 +145,7 @@ def unset_fake_temporarily():
             torch._C._set_dispatch_mode(old)
 
 
-def get_plain_tensors(subclass):
+def get_plain_tensors(subclass: Tensor) -> List[Tensor]:
     assert is_traceable_wrapper_subclass(subclass)
     plain_tensors = []
     todo = [subclass]
@@ -166,7 +161,7 @@ def get_plain_tensors(subclass):
     return plain_tensors
 
 
-def is_fake(x: object) -> TypeGuard[torch.Tensor]:
+def is_fake(x: object) -> TypeGuard[Tensor]:
     if isinstance(x, FakeTensor):
         return True
     if is_traceable_wrapper_subclass(x):
@@ -176,17 +171,17 @@ def is_fake(x: object) -> TypeGuard[torch.Tensor]:
         any_fake = any(is_fake(x) for x in flattened_tensors)
         assert all_fake == any_fake, "got mixed fake and real tensors!"
         return all_fake
-    elif isinstance(x, torch.Tensor) and torch._is_functional_tensor(x):
+    elif isinstance(x, Tensor) and torch._is_functional_tensor(x):
         reapply_views = torch._C._functionalization_reapply_views_tls()
         unwrapped = torch._C._functorch._unwrap_functional_tensor(x, reapply_views)
         return is_fake(unwrapped)
-    elif isinstance(x, torch.Tensor) and is_functorch_wrapped_tensor(x):
+    elif isinstance(x, Tensor) and is_functorch_wrapped_tensor(x):
         unwrapped = torch._C._functorch.get_unwrapped(x)
         return is_fake(unwrapped)
     return False
 
 
-def maybe_get_fake_mode(t):
+def maybe_get_fake_mode(t: object) -> Optional[FakeTensorMode]:
     if isinstance(t, FakeTensor):
         return t.fake_mode
     if is_traceable_wrapper_subclass(t):
@@ -197,19 +192,19 @@ def maybe_get_fake_mode(t):
         m = modes[0]
         assert all(m is x for x in modes)
         return m
-    elif isinstance(t, torch.Tensor) and torch._is_functional_tensor(t):
+    elif isinstance(t, Tensor) and torch._is_functional_tensor(t):
         reapply_views = torch._C._functionalization_reapply_views_tls()
         unwrapped = torch._C._functorch._unwrap_functional_tensor(t, reapply_views)
         return maybe_get_fake_mode(unwrapped)
-    elif isinstance(t, torch.Tensor) and is_functorch_wrapped_tensor(t):
+    elif isinstance(t, Tensor) and is_functorch_wrapped_tensor(t):
         unwrapped = torch._C._functorch.get_unwrapped(t)
         return maybe_get_fake_mode(unwrapped)
     return None
 
 
 @functools.lru_cache(None)
-def get_schema_info(func):
-    return torch._C._SchemaInfo(func._schema)  # type: ignore[attr-defined]
+def get_schema_info(func: OpOverload) -> torch._C._SchemaInfo:
+    return torch._C._SchemaInfo(func._schema)
 
 
 # many of the decompositions registered to torch/_prims do not at the moment model
@@ -218,7 +213,7 @@ def get_schema_info(func):
 # decomps are used for aot autograd tracing so we would like to unify on their
 # implementation and add additional testing to them
 @functools.lru_cache(None)
-def torch_decomp_decompositions(func):
+def torch_decomp_decompositions(func: OpOverload) -> bool:
     from torch._decomp import decomposition_table
 
     decompositions = torch._decomp.decompositions
@@ -230,32 +225,50 @@ def torch_decomp_decompositions(func):
     ) and decomposition_table[func].__name__ in dir(decompositions)
 
 
-def tree_flatten_only(ty: Type[T], tree: PyTree):
+def tree_flatten_only(ty: Type[T], tree: PyTree) -> List[T]:
     flat_vals = pytree.tree_leaves(tree)
     return [elem for elem in flat_vals if isinstance(elem, ty)]
 
 
+def _is_plain_tensor(t: object) -> bool:
+    return (
+        type(t) is Tensor
+        and t.layout == torch.strided
+        and not (
+            t.is_sparse
+            or t.is_nested
+            or is_functorch_wrapped_tensor(t)
+            or is_legacy_batchedtensor(t)
+            or torch._is_functional_tensor(t)
+        )
+    )
+
+
 # Similar to `MetaConverter`, this is a class for converting
 # multiple tensors into fake tensors which share the same view/storage
 # structure. Like `MetaConverter`, it uses `WeakIdRef` to
 # hold a weak reference for all memoized tensors.
 class FakeTensorConverter:
     @property
-    def tensor_memo(self):
+    def tensor_memo(
+        self,
+    ) -> weakref.WeakValueDictionary:
+        # not valid until py3.10
+        # weakref.WeakValueDictionary["torch._subclasses.meta_utils.MetaTensorId", Optional["FakeTensor"]]
        return self.meta_converter.tensor_memo
 
     meta_converter: MetaConverter
     constant_storage_mapping: Dict[StorageWeakRef, List[ReferenceType]]
     export: bool
 
-    def __init__(self, *, copy_data=False, export=False):
+    def __init__(self, *, copy_data: bool = False, export: bool = False) -> None:
         self.meta_converter = MetaConverter(copy_data=copy_data)
         self.export = export
 
         # map from to storage to corresponding constant tensors
         self.constant_storage_mapping = {}
 
-    def add_constant_storage_mapping(self, fake_tensor):
+    def add_constant_storage_mapping(self, fake_tensor: FakeTensor) -> None:
         # when you have a constant, aliased tensor:
         # const_tensor.add_(torch.rand([1]))
         # all aliases of it must become no longer const
@@ -269,7 +282,7 @@ class FakeTensorConverter:
             self.constant_storage_mapping[weak_st] = []
         self.constant_storage_mapping[weak_st].append(weakref.ref(fake_tensor))
 
-    def invalidate_constant_aliases(self, tensor):
+    def invalidate_constant_aliases(self, tensor: Tensor) -> None:
         assert not isinstance(tensor, FakeTensor)
 
         weak_st = StorageWeakRef(tensor._typed_storage())
@@ -284,13 +297,13 @@ class FakeTensorConverter:
 
             del self.constant_storage_mapping[weak_st]
 
-    def _get_memo(self, t):
+    def _get_memo(self, t: Tensor) -> Optional[FakeTensor]:
         tid = self.meta_converter.describer.lookup_tensor.get(t)
         if tid is None:
             return None
         return self.tensor_memo.get(tid)
 
-    def set_tensor_memo(self, t, v):
+    def set_tensor_memo(self, t: Tensor, v: FakeTensor) -> None:
         tid = self.meta_converter.describer.get_tensor_id(t)
         self.meta_converter.tensor_memo[tid] = v
 
@@ -302,20 +315,25 @@
     # cross ref testing and the inner test is already operating on meta tensors.
     def from_real_tensor(
         self,
-        fake_mode,
-        t,
-        make_constant=False,
-        shape_env=None,
+        fake_mode: FakeTensorMode,
+        t: Tensor,
+        make_constant: bool = False,
+        shape_env: Optional[ShapeEnv] = None,
         *,
-        source=None,
-        symbolic_context=None,
-        trace=True,
-    ):
+        source: Optional[Source] = None,
+        symbolic_context: Optional[SymbolicContext] = None,
+        trace: bool = True,
+    ) -> FakeTensor:
         # see note [Tensor Fakification and Symbol Caching]
         if not symbolic_context and not source and shape_env:
             if tracing_context := torch._guards.TracingContext.try_get():
                 if t in tracing_context.tensor_to_context:
                     symbolic_context = tracing_context.tensor_to_context[t]
+                    from torch.fx.experimental.symbolic_shapes import (
+                        StatefulSymbolicContext,
+                    )
+
+                    assert isinstance(symbolic_context, StatefulSymbolicContext)
                     source = symbolic_context.tensor_source
 
         maybe_memo = self._get_memo(t)
@@ -328,7 +346,7 @@ class FakeTensorConverter:
         if type(t) is torch.nn.Parameter:
             assert not make_constant
 
-        def mk_fake_tensor(make_meta_t):
+        def mk_fake_tensor(make_meta_t: Callable[[], object]) -> FakeTensor:
             # NB: don't use in_kernel_invocation_manager. to
             # ensure FakeTensor can internally do constant computation
             # as necessary. Invocation manager is "more correct" as
@@ -432,7 +450,9 @@ class FakeTensorConverter:
         return out
 
     # If you specify the device, it MUST be a meta tensor.
-    def from_meta_and_device(self, fake_mode, t, device):
+    def from_meta_and_device(
+        self, fake_mode: FakeTensorMode, t: Tensor, device: torch.device
+    ) -> FakeTensor:
         assert (
             t.device.type == "meta"
         ), f"tensor's device must be `meta`, got {t.device.type} instead"
@@ -447,7 +467,7 @@ class FakeTensorConverter:
 
 
 @functools.lru_cache(None)
-def init_cuda_context():
+def init_cuda_context() -> None:
     # Backward will error with cuda Fake Tensors if no cuda tensors have been initialized first
     if torch.cuda.is_available():
         (
@@ -458,7 +478,9 @@ def init_cuda_context():
 
 
 @contextlib.contextmanager
-def in_kernel_invocation_manager(fake_mode):
+def in_kernel_invocation_manager(
+    fake_mode: FakeTensorMode,
+) -> Generator[None, None, None]:
     # See: note [Fake Tensor Dispatch Keys]
     prev_in_kernel = fake_mode.in_kernel_invocation
     meta_in_tls = torch._C._meta_in_tls_dispatch_include()
@@ -478,7 +500,7 @@ def in_kernel_invocation_manager(fake_mode):
 
 
 # Return if the function allows Python numbers to bind to Tensors
-def should_allow_numbers_as_tensors(func: OpOverload):
+def should_allow_numbers_as_tensors(func: OpOverload) -> bool:
     return torch._C._should_allow_numbers_as_tensors(
         func.name().split("::")[-1].split(".")[0]
     )
@@ -502,23 +524,25 @@ class FakeTensorConfig:
 class UnbackedMemoDescriptor:
     _name: str
 
-    def __set_name__(self, owner, name):
+    def __set_name__(self, owner: str, name: str) -> None:
         self._name = name
 
-    def _memo(self, obj):
+    def _memo(self, obj: FakeTensor) -> str:
         return f"_{self._name}"
 
-    def _memo_vc(self, obj):
+    def _memo_vc(self, obj: FakeTensor) -> str:
         return f"_{self._name}_vc"
 
     # When we retrace, we need to invalidate all the memos so that we can
    # accurately identify the first time unbacked SymInts are allocated.
    # This is only relevant for inputs; for intermediates, they will get fresh
    # fake tensors so you won't have a memo anyway
-    def _memo_epoch(self, obj):
+    def _memo_epoch(self, obj: FakeTensor) -> str:
         return f"_{self._name}_epoch"
 
-    def __get__(self, obj: "FakeTensor", objtype=None):
+    def __get__(
+        self, obj: FakeTensor, objtype: Optional[Type[FakeTensor]] = None
+    ) -> Optional[object]:
         if (r := getattr(obj, self._memo(obj))) is None:
             return None
         # Version counter based tracking isn't 100% sound but it's close
@@ -531,7 +555,7 @@ class UnbackedMemoDescriptor:
             return None
         return r
 
-    def __set__(self, obj, value):
+    def __set__(self, obj: FakeTensor, value: Optional[object]) -> None:
         if value is None:
             setattr(obj, self._memo(obj), None)
             setattr(obj, self._memo_vc(obj), None)
@@ -542,7 +566,7 @@ class UnbackedMemoDescriptor:
             setattr(obj, self._memo_epoch(obj), obj.fake_mode.epoch)
 
 
-class FakeTensor(torch.Tensor):
+class FakeTensor(Tensor):
     """
     Meta tensors give you the ability to run PyTorch code without having to
     actually do computation through tensors allocated on a `meta` device.
@@ -552,9 +576,9 @@ class FakeTensor(torch.Tensor):
     """
 
     fake_device: torch.device
-    fake_mode: "FakeTensorMode"
-    constant: Optional[torch.Tensor]
-    real_tensor: Optional[torch.Tensor]
+    fake_mode: FakeTensorMode
+    constant: Optional[Tensor]
+    real_tensor: Optional[Tensor]
 
     # TODO: Generalize this as needed, e.g., into a trie of memos, if
     # you do something like x[0].item()  (x[0] is fresh each time, so
@@ -568,18 +592,22 @@ class FakeTensor(torch.Tensor):
     _mode_key = torch._C._TorchDispatchModeKey.FAKE
 
     @property
-    def device(self):
+    def device(self) -> torch.device:
         if self.fake_mode.in_kernel_invocation:
             return torch.device("meta")
         else:
             return self.fake_device
 
+    @device.setter
+    def device(self, _: torch.device) -> None:
+        raise NotImplementedError
+
     # Note: [Fake Tensor Dispatch Keys]
     # In order to model the behavior of device-specific autocast
     # and autograd logic, we update the dispatch keys of FakeTensors
     # to reflect their fake device. This includes the BackendComponent
     # (DispatchKey::Meta -> DispatchKey::CUDA), and also the BackendComponent
-    # related Autocast and Autograd keys. __torch__dispatch__ sits below
+    # related Autocast and Autograd keys. __torch_dispatch__ sits below
     # Autocast and Autograd, and is only invoked when we are at the
     # kernel for the BackendComponent. Then, we add Meta to the
     # thread-local dispatch include set to hit the meta kernel
@@ -591,14 +619,25 @@ class FakeTensor(torch.Tensor):
 
     # We don't support named tensors; graph break
     @property
-    def names(self):
+    def names(self) -> List[str]:
         raise UnsupportedFakeTensorException(
             "torch.compile doesn't support named tensors"
         )
 
+    @names.setter
+    def names(self, _: List[str]) -> None:
+        raise NotImplementedError
+
     @staticmethod
-    def __new__(cls, fake_mode, elem, device, constant=None, real_tensor=None):
-        self = torch.Tensor._make_subclass(
+    def __new__(
+        cls,
+        fake_mode: FakeTensorMode,
+        elem: Tensor,
+        device: torch.device,
+        constant: Optional[Tensor] = None,
+        real_tensor: Optional[Tensor] = None,
+    ) -> Self:
+        self = Tensor._make_subclass(
             cls,
             elem,
             elem.requires_grad,
@@ -636,11 +675,11 @@ class FakeTensor(torch.Tensor):
             )
         else:
             device = torch.device(f"{device.type}:0")
-        self.fake_device = device  # type: ignore[attr-defined]
-        self.fake_mode = fake_mode  # type: ignore[attr-defined]
-        self.constant = constant  # type: ignore[attr-defined]
+        self.fake_device = device
+        self.fake_mode = fake_mode
+        self.constant = constant
         assert not isinstance(real_tensor, FakeTensor)
-        self.real_tensor = real_tensor  # type: ignore[attr-defined]
+        self.real_tensor = real_tensor
         self.nonzero_memo = None
         self.item_memo = None
         self.unique_memo = None
@@ -649,7 +688,7 @@ class FakeTensor(torch.Tensor):
             self._debug_trace = CapturedTraceback.extract()  # type: ignore[attr-defined]
         return self
 
-    # In some circumstances, a conventional torch.Tensor constructor
+    # In some circumstances, a conventional Tensor constructor
     # will get rewritten to call into FakeTensor. We must provide an
     # __init__ method that can accept the Python interpreters initialization
     # in such a situation; we must also be able to handle direct fake
@@ -658,25 +697,31 @@ class FakeTensor(torch.Tensor):
     # In particular, the __init__ call will look funny in the following case:
     #
     #   with FakeTensorMode():
-    #       x = torch.Tensor([1, 2, 3])
+    #       x = Tensor([1, 2, 3])
     #
     # this desugars into:
     #
     #   with FakeTensorMode():
-    #       x = torch.Tensor.__new__([1, 2, 3])
+    #       x = Tensor.__new__([1, 2, 3])
     #       # NB: x is a fake tensor, because of the mode!
     #       x.__init__([1, 2, 3])  # not the normal fake tensor args!
     #
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args: object, **kwargs: object) -> None:
         super().__init__()
 
     @staticmethod
-    def from_tensor(t, fake_mode):
+    def from_tensor(t: Tensor, fake_mode: FakeTensorMode) -> FakeTensor:
         return fake_mode.from_tensor(t)
 
     @classmethod
     @count
-    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+    def __torch_dispatch__(
+        cls,
+        func: OpOverload,
+        types: Sequence[Type],
+        args: Sequence[object] = (),
+        kwargs: Mapping[str, object] = immutable_dict(),
+    ) -> object:
         # need to handle here to avoid infinite recursion
         # see [in_kernel_invocation]
         if func == torch.ops.prim.device.default:
@@ -704,7 +749,7 @@ class FakeTensor(torch.Tensor):
         # subclasses of tensors to dispatch, and any FakeTensor arguments
         # will be considered eligible.
         unrecognized_types = [
-            t for t in types if not issubclass(t, FakeTensor) and t is not torch.Tensor
+            t for t in types if not issubclass(t, FakeTensor) and t is not Tensor
         ]
         if unrecognized_types:
             not_implemented_log.debug(
@@ -740,11 +785,13 @@ class FakeTensor(torch.Tensor):
 
         assert not fake_mode.in_kernel_invocation
 
-        with fake_mode:  # type: ignore[attr-defined]
+        with fake_mode:
             return func(*args, **kwargs)
 
     @staticmethod
-    def _find_common_device(func, flat_args) -> Tuple[torch.device, bool]:
+    def _find_common_device(
+        func: OpOverload, flat_args: Sequence[object]
+    ) -> Tuple[torch.device, bool]:
         # Returns: (common_device, has_scalar_only_inputs)
 
         # cpu - zero-dim tensors can be called in cuda kernels,
@@ -754,10 +801,10 @@ class FakeTensor(torch.Tensor):
         has_scalar_only_inputs = False
         is_cpu_zero_dim = None
 
-        def cpu_zero_dim(t):
+        def cpu_zero_dim(t: Tensor) -> bool:
             return t.device.type == "cpu" and t.dim() == 0
 
-        def merge_devices(t):
+        def merge_devices(t: object) -> None:
             nonlocal common_device
             nonlocal is_cpu_zero_dim
             if not isinstance(t, FakeTensor):
@@ -815,9 +862,10 @@ class FakeTensor(torch.Tensor):
     # To avoid this, we handle the FakeTensor case by (1) specializing on the size
     # of the tensor to create the output Python list, and (2) creating unbacked
     # symints for each element of the list.
-    def tolist(self):
+    def tolist(self) -> List[SymInt]:
         assert self.dim() == 1, "NYI for higher dims"
         shape_env = self.fake_mode.shape_env
+        assert shape_env is not None
         out = []
         # Specialize on the length of the list
         for _ in range(self.shape[0]):
@@ -837,7 +885,7 @@ class TensorMetadata:
 
     dtype: torch.dtype
     shape: torch.Size
-    stride: Tuple[Any, ...]
+    stride: Tuple[IntLikeType, ...]
     device: torch.device
     layout: torch.layout
     memory_format: Optional[torch.memory_format]
@@ -854,7 +902,7 @@ class TensorMetadata:
     sparse_dim: Optional[int]
 
 
-def extract_tensor_metadata(t: torch.Tensor) -> "TensorMetadata":
+def extract_tensor_metadata(t: Tensor) -> TensorMetadata:
     """
     Extract the TensorMetadata of a tensor.
     """
@@ -862,6 +910,8 @@ def extract_tensor_metadata(t: torch.Tensor) -> "TensorMetadata":
     if is_sparse_any(t) or not t.is_contiguous(memory_format=memory_format):
         memory_format = None
 
+    storage_offset = t.storage_offset()
+
     return TensorMetadata(
         dtype=t.dtype,
         shape=t.shape,
@@ -869,7 +919,7 @@ def extract_tensor_metadata(t: torch.Tensor) -> "TensorMetadata":
         device=t.device,
         layout=t.layout,
         memory_format=memory_format,
-        storage_offset=t.storage_offset(),
+        storage_offset=storage_offset,
         # Only set storage_bytes for tensors that have storage (not sparse)
         storage_bytes=t.untyped_storage().nbytes() if not t.is_sparse else None,
         requires_grad=t.requires_grad,
@@ -892,11 +942,13 @@ class _DispatchCacheKey(list):
 
     __slots__ = "hashvalue"  # noqa: PLC0205
 
-    def __init__(self, tup, hash=hash):
+    def __init__(
+        self, tup: Tuple[object, ...], hash: Callable[[object], int] = hash
+    ) -> None:
         self[:] = tup
         self.hashvalue = hash(tup)
 
-    def __hash__(self):
+    def __hash__(self) -> int:  # type: ignore[override]
         return self.hashvalue
 
@@ -953,14 +1005,18 @@ class FakeTensorMode(TorchDispatchMode):
     # advance the epoch so we don't reuse unbacked memos
     epoch: int = 0
     in_kernel_invocation: bool = False
+    static_shapes: bool
+    shape_env: Optional[ShapeEnv]
+    _stack: Optional[str]
+    allow_meta: bool
 
     def __init__(
         self,
         *,
-        allow_fallback_kernels=True,
-        allow_non_fake_inputs=False,
-        shape_env=None,
-        static_shapes=None,
+        allow_fallback_kernels: bool = True,
+        allow_non_fake_inputs: bool = False,
+        shape_env: Optional[ShapeEnv] = None,
+        static_shapes: Optional[bool] = None,
         # TODO: This is a temporary measure, see
         # https://github.com/pytorch/pytorch/pull/126245#discussion_r1604185748
         # We're currently solely using this to impede population of
@@ -970,8 +1026,8 @@ class FakeTensorMode(TorchDispatchMode):
         # this by ensuring guards also get put in the graph, but this is
         # pending a rework of how deferred runtime asserts in export. Once
         # that's done, we can remove this.
-        export=False,
-    ):
+        export: bool = False,
+    ) -> None:
         log.debug("create_mode 0x%x", id(self))
         self.allow_fallback_kernels = allow_fallback_kernels
 
@@ -1029,10 +1085,10 @@ class FakeTensorMode(TorchDispatchMode):
         # If another fake mode was already active when we enter, we also stash it here.
         # That way when we exit, we know to re-enable the previous fake mode.
         self.enter_stack: List[
-            Tuple[bool, Optional[TorchDispatchMode], Optional[_bool]]
+            Tuple[bool, Optional[TorchDispatchMode], Optional[bool]]
         ] = []
 
-        self.shape_env: ShapeEnv = shape_env
+        self.shape_env = shape_env
 
         self._stack_trace = traceback.extract_stack()
         self._stack = None
@@ -1053,7 +1109,7 @@ class FakeTensorMode(TorchDispatchMode):
    # In this case, it's insufficient to test only one FakeTensor: you need
    # to distinguish between our fake tensor and other fake tensors.  That's
    # what this function does.
-    def is_our_fake(self, t):
+    def is_our_fake(self, t: object) -> TypeGuard[FakeTensor]:
         return isinstance(t, FakeTensor) and t.fake_mode is self
 
     # If we should avoid device init. This changes the behavior of various APIs:
@@ -1062,17 +1118,23 @@ class FakeTensorMode(TorchDispatchMode):
     #   tensors on device
     #   (see NOTE: [torch.tensor, lift_fresh, and device movement])
     @property
-    def avoid_device_init(self):
+    def avoid_device_init(self) -> bool:
         return not torch.cuda.is_available()
 
     @property
-    def stack(self):
+    def stack(self) -> str:
         if self._stack is None:
             self._stack = "".join(traceback.format_list(self._stack_trace))
         return self._stack
 
     @count
-    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+    def __torch_dispatch__(
+        self,
+        func: OpOverload,
+        types: Sequence[Type],
+        args: Sequence[object] = (),
+        kwargs: Mapping[str, object] = immutable_dict(),
+    ) -> object:
         # FakeTensorMode should not be set when we're inside of it.
         assert (
             torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE) is None
@@ -1084,7 +1146,7 @@ class FakeTensorMode(TorchDispatchMode):
             raise
 
     # No-op if FakeTensorMode is already in use
-    def __enter__(self):
+    def __enter__(self) -> Self:
         prev_only_lift_cpu_tensors = None
         if self.avoid_device_init:
             # See NOTE: [torch.tensor, lift_fresh, and device movement]
@@ -1102,7 +1164,12 @@ class FakeTensorMode(TorchDispatchMode):
             self.enter_stack.append((False, None, prev_only_lift_cpu_tensors))
         return self
 
-    def __exit__(self, a, b, c):
+    def __exit__(
+        self,
+        a: Optional[Type[BaseException]],
+        b: Optional[BaseException],
+        c: Optional[TracebackType],
+    ) -> None:
         (
             live,
             maybe_prev_fake_mode,
@@ -1129,7 +1196,7 @@ class FakeTensorMode(TorchDispatchMode):
         )
 
     @classmethod
-    def cache_clear(cls):
+    def cache_clear(cls) -> None:
         """
         Clear the dispatch cache.
         """
@@ -1141,15 +1208,15 @@ class FakeTensorMode(TorchDispatchMode):
     def _cached_dispatch_impl(
         self,
         func: OpOverload,
-        types: Tuple[Any, ...],
-        args: Tuple[Any, ...],
-        kwargs: Dict[str, Any],
-    ):
+        types: Sequence[Type],
+        args: Sequence[object],
+        kwargs: Mapping[str, object],
+    ) -> object:
         """
         Lookup a cache entry for the given arguments. If none exists, dispatch
         and cache the result (if the result is eligible for caching).
         """
-        output: Union[FakeTensor, _Unassigned] = _UNASSIGNED
+        output: object = _UNASSIGNED
         try:
             key = self._cache_key(func, args, kwargs)
             entry = FakeTensorMode.cache.get(key, None)
@@ -1177,8 +1244,8 @@ class FakeTensorMode(TorchDispatchMode):
     def _cache_key(
         self,
         func: OpOverload,
-        args: Tuple[Any, ...],
-        kwargs: Dict[str, Any],
+        args: Sequence[object],
+        kwargs: Mapping[str, object],
     ) -> _DispatchCacheKey:
         """
         Create a cache key given the dispatch args. Raises _BypassDispatchCache
@@ -1208,9 +1275,9 @@ class FakeTensorMode(TorchDispatchMode):
     def _validate_cache_key(
         self,
         func: OpOverload,
-        args: Tuple[Any, ...],
-        kwargs: Dict[str, Any],
-    ):
+        args: Sequence[object],
+        kwargs: Mapping[str, object],
+    ) -> None:
         """
         Validate that the cache key generated by _cache_key will be
         reasonable.
@@ -1247,7 +1314,9 @@ class FakeTensorMode(TorchDispatchMode):
         ):
             raise _BypassDispatchCache("CompositeImplicitAutograd")
 
-    def _prep_args_for_hash(self, args: Any) -> Any:
+    def _prep_args_for_hash(
+        self, args: Union[Mapping[str, object], Sequence[object]]
+    ) -> Tuple[object, ...]:
         """
         Translate the provided args into a form suitable for caching at FakeTensor
         dispatch, i.e., convert unhashable types like lists & dicts into tuples and
@@ -1257,7 +1326,7 @@ class FakeTensorMode(TorchDispatchMode):
         if isinstance(args, dict):
             args = list(args.keys()) + list(args.values())
 
-        result: List[Any] = []
+        result: List[object] = []
         for arg in args:
             if isinstance(arg, FakeTensor):
                 if not self.is_our_fake(arg):
@@ -1277,14 +1346,14 @@ class FakeTensorMode(TorchDispatchMode):
                     # Does this subsume arg.is_sparse?
                     raise _BypassDispatchCache("sparse tensor layout")
                 # sparse tensors don't have storage, so check is after
-                if isinstance(arg.untyped_storage().nbytes(), torch.SymInt):
+                if isinstance(arg.untyped_storage().nbytes(), SymInt):
                     raise _BypassDispatchCache("symbolic nbytes")
                 if is_sparse_compressed(arg):
                     raise _BypassDispatchCache("sparse compressed tensor")
                 result.append(extract_tensor_metadata(arg))
-            elif isinstance(arg, torch.Tensor):
+            elif isinstance(arg, Tensor):
                 raise _BypassDispatchCache("non-fake tensor")
-            elif isinstance(arg, (torch.SymBool, torch.SymInt, torch.SymFloat)):
+            elif isinstance(arg, (SymBool, SymInt, SymFloat)):
                 raise _BypassDispatchCache("symbolic shape")
             elif isinstance(arg, (list, tuple, dict)):
                 result.extend(self._prep_args_for_hash(arg))
@@ -1300,15 +1369,18 @@ class FakeTensorMode(TorchDispatchMode):
         self,
         key: _DispatchCacheKey,
         func: OpOverload,
-        args: Tuple[Any, ...],
-        kwargs: Dict[str, Any],
-        output: FakeTensor,
+        args: Sequence[object],
+        kwargs: Mapping[str, object],
+        output: Optional[FakeTensor],
     ) -> _DispatchCacheEntry:
         """
         Make a cache entry object for the given 'output' Tensor. Raises
         _BypassDispatchCache if the output tensor has characteristics that
         prevent caching it.
         """
+        if output is None:
+            return _DispatchCacheEntry(inplace_idx=None, metadata=None, view_idx=None)
+
         # Some ops return tuples of Tensors, but it's rare, so avoid
         # the complexity of caching other types.
         if not isinstance(output, FakeTensor):
@@ -1342,7 +1414,7 @@ class FakeTensorMode(TorchDispatchMode):
         # Otherwise, create an entry that records the output tensor's metadata.
         view_idx = None
         if func.is_view:
-            idxs = [i for i, t in enumerate(args) if isinstance(t, torch.Tensor)]
+            idxs = [i for i, t in enumerate(args) if isinstance(t, Tensor)]
             assert len(idxs) == 1
             view_idx = idxs[0]
 
@@ -1368,18 +1440,23 @@ class FakeTensorMode(TorchDispatchMode):
         return entry
 
     def _output_from_cache_entry(
-        self, entry: _DispatchCacheEntry, func: OpOverload, args: Tuple[Any, ...]
-    ) -> FakeTensor:
+        self, entry: _DispatchCacheEntry, func: OpOverload, args: Sequence[object]
+    ) -> Optional[FakeTensor]:
         """
         Create a new FakeTensor from the cache entry.
         """
         if entry.inplace_idx is not None:
             # This is an in-place op; return the aliased arg.
-            return args[entry.inplace_idx]
+            inplace_arg = args[entry.inplace_idx]
+            assert isinstance(inplace_arg, FakeTensor)
+            return inplace_arg
 
         # Synthesize a new FakeTensor with the cached metadata.
         metadata = entry.metadata
-        assert metadata and not metadata.is_sparse
+        if metadata is None:
+            return None
+
+        assert not metadata.is_sparse
 
         empty = torch.empty_strided(
             metadata.shape,
@@ -1395,13 +1472,15 @@ class FakeTensorMode(TorchDispatchMode):
         if metadata.is_neg:
             torch._C._set_neg(empty, True)
 
-        maybe_suppress: Callable[[], Any] = contextlib.nullcontext
+        maybe_suppress: Callable[[], typing.ContextManager] = contextlib.nullcontext
         if self.shape_env is not None:
             maybe_suppress = self.shape_env.suppress_guards
 
         if func.is_view:
             # For view ops, the storage should be the same as the tensor input.
-            storage = args[cast(int, entry.view_idx)].untyped_storage()
+            view_arg = args[cast(int, entry.view_idx)]
+            assert isinstance(view_arg, FakeTensor)
+            storage = view_arg.untyped_storage()
             with in_kernel_invocation_manager(self), maybe_suppress():
                 empty.set_(
                     storage, metadata.storage_offset, metadata.shape, metadata.stride
@@ -1419,12 +1498,12 @@ class FakeTensorMode(TorchDispatchMode):
     def _crosscheck_cache_output(
         self,
-        output: FakeTensor,
+        output: Optional[FakeTensor],
         func: OpOverload,
-        types: Tuple[Any, ...],
-        args: Tuple[Any, ...],
-        kwargs: Dict[str, Any],
-    ):
+        types: Sequence[Type],
+        args: Sequence[object],
+        kwargs: Mapping[str, object],
+    ) -> None:
         """
         Helper to validate that the output synthesized from the cache matches
         the output created by normal dispatch.
@@ -1437,14 +1516,24 @@ class FakeTensorMode(TorchDispatchMode):
                 f"args={args}, kwargs={kwargs}: Dispatch raised={e}"
             ) from e
         try:
-            assert_metadata_eq(assert_eq, true_output, output)
+            if (true_output is not None) and (output is not None):
+                assert_metadata_eq(assert_eq, true_output, output)
+            else:
+                assert true_output is None
+                assert output is None
         except Exception as e:
             raise RuntimeError(
                 f"FakeTensor cache crosscheck failure: func={func}, "
                 f"args={args}, kwargs={kwargs}"
             ) from e
 
-    def dispatch(self, func, types, args=(), kwargs=None):
+    def dispatch(
+        self,
+        func: OpOverload,
+        types: Sequence[Type],
+        args: Sequence[object] = (),
+        kwargs: Mapping[str, object] = immutable_dict(),
+    ) -> object:
         kwargs = kwargs or {}
         with no_dispatch():
             log.debug("%s %s %s", func, args, kwargs)
@@ -1471,13 +1560,19 @@ class FakeTensorMode(TorchDispatchMode):
         else:
             return self._dispatch_impl(func, types, args, kwargs)
 
-    def _dispatch_impl(self, func, types, args, kwargs) -> FakeTensor:
+    def _dispatch_impl(
+        self,
+        func: OpOverload,
+        types: Sequence[Type],
+        args: Sequence[object],
+        kwargs: Mapping[str, object],
+    ) -> Optional[FakeTensor]:
         flat_args, args_spec = pytree.tree_flatten((args, kwargs))
 
         flat_arg_fake_tensors = [t for t in flat_args if self.is_our_fake(t)]
         has_symbolic_sizes = any(
             i._has_symbolic_sizes_strides for i in flat_arg_fake_tensors
-        ) or any(isinstance(a, torch.SymInt) for a in flat_args)
+        ) or any(isinstance(a, SymInt) for a in flat_args)
 
         converter = self.fake_tensor_converter
 
@@ -1504,7 +1599,7 @@ class FakeTensorMode(TorchDispatchMode):
             ]
             const_args, const_kwargs = pytree.tree_unflatten(const_flat_args, args_spec)
             out = func(*const_args, **const_kwargs)
-            if type(out) is torch.Tensor and self.may_turn_const(out):
+            if type(out) is Tensor and self.may_turn_const(out):
                 # NB: not in_kernel_invocation_manager because we're doing real
                 # compute here
                 # NB: no_dispatch() here is VERY DANGEROUS (like, segfault
@@ -1538,7 +1633,7 @@ class FakeTensorMode(TorchDispatchMode):
         if is_lift_func:
             assert len(kwargs) == 0 and len(args) == 1, f"{args} {kwargs}"
 
-            if type(args[0]) is torch.Tensor:
+            if type(args[0]) is Tensor:
                 return converter.from_real_tensor(self, args[0])
 
         # If we are trying to avoid device init, then we need to avoid constant
@@ -1591,12 +1686,12 @@ class FakeTensorMode(TorchDispatchMode):
                 out = func(*const_args, **const_kwargs)
 
                 flat_out = pytree.tree_leaves(out)
-                flat_out_tensors = [t for t in flat_out if isinstance(t, torch.Tensor)]
+                flat_out_tensors = [t for t in flat_out if isinstance(t, Tensor)]
                 all_constant = all(self.may_turn_const(t) for t in flat_out_tensors)
 
                 if all_constant:
                     return pytree.tree_map_only(
-                        torch.Tensor,
+                        Tensor,
                         lambda t: converter.from_real_tensor(self, t, make_constant=True),
                         out,
                     )
@@ -1611,10 +1706,11 @@ class FakeTensorMode(TorchDispatchMode):
         args, kwargs = pytree.tree_unflatten(flat_args, args_spec)
         self.invalidate_written_to_constants(func, flat_arg_fake_tensors, args, kwargs)
 
-        def maybe_to_real_tensor(t):
+        def maybe_to_real_tensor(t: T) -> Optional[Union[T, Tensor]]:
             if isinstance(t, FakeTensor):
                 return t.real_tensor
             elif isinstance(t, SymTypes):
+                assert self.shape_env is not None
                 return t.node.pytype(
                     t.node.expr.xreplace(self.shape_env.var_to_val).xreplace(
                         self.shape_env.unbacked_var_to_val
@@ -1638,8 +1734,9 @@ class FakeTensorMode(TorchDispatchMode):
             # TODO: Handle SymFloat/SymBool
             and not any(
                 (
-                    isinstance(a, torch.SymInt)
+                    isinstance(a, SymInt)
                     and (syms := free_unbacked_symbols(a))
+                    and self.shape_env is not None
                     and any(s not in self.shape_env.unbacked_var_to_val for s in syms)
                 )
                 for a in flat_args
@@ -1663,15 +1760,16 @@ class FakeTensorMode(TorchDispatchMode):
                 self.shape_env.unbacked_var_to_val if self.shape_env else None,
             )
 
-        def maybe_propagate_real_tensors(fake_out):
+        def maybe_propagate_real_tensors(fake_out: T) -> T:
             import sympy
 
-            def go(t, real_t):
+            def go(t: object, real_t: Tensor) -> None:
                 if isinstance(t, FakeTensor):
                     # NB: unconditionally overwrite
                     t.real_tensor = real_t
                 elif isinstance(t, SymTypes) and free_unbacked_symbols(t):
                     if isinstance(t.node.expr, sympy.Symbol):
+                        assert self.shape_env is not None
                         self.shape_env.set_unbacked_var_to_val(t.node.expr, real_t)
 
             if real_out is not nil:
@@ -1751,7 +1849,9 @@ class FakeTensorMode(TorchDispatchMode):
                 if op_impl_out is not NotImplemented:
                     return maybe_propagate_real_tensors(op_impl_out)
 
-        def maybe_run_unsafe_fallback(error=None):
+        def maybe_run_unsafe_fallback(
+            error: Optional[RuntimeError] = None,
+        ) -> Optional[FakeTensor]:
            # We infer the meta of a custom ops that return None to just
            # return None. custom ops are not allowed to mutate metadata
            # of their inputs, so this is safe.
@@ -1767,7 +1867,8 @@ class FakeTensorMode(TorchDispatchMode):
         # Optimization: If there is no Meta kernel, it takes a surprisingly long
         # amount of time to catch the NotImplementedError, so we check it here.
         if not has_meta(func):
-            return maybe_propagate_real_tensors(maybe_run_unsafe_fallback())
+            fallback = maybe_run_unsafe_fallback()
+            return maybe_propagate_real_tensors(fallback)
 
         # run kernel registered to meta for func, which include
         # python meta registrations, prims, decomps, and c++ meta fns (structured kernels)
@@ -1803,7 +1904,7 @@ class FakeTensorMode(TorchDispatchMode):
         "quantized",
     )
 
-    def can_run_unsafe_fallback(self, func: OpOverload):
+    def can_run_unsafe_fallback(self, func: OpOverload) -> bool:
         if not self.allow_fallback_kernels:
             return False
         # It's OK to try the fallback for built-in ops (e.g. aten, prims)
@@ -1815,17 +1916,21 @@ class FakeTensorMode(TorchDispatchMode):
         )
 
     def validate_and_convert_non_fake_tensors(
-        self, func, converter, flat_args, args_spec
-    ):
+        self,
+        func: OpOverload,
+        converter: FakeTensorConverter,
+        flat_args: Sequence[object],
+        args_spec: TreeSpec,
+    ) -> Tuple[List[object], List[FakeTensor]]:
         """
         Checks if the list of tensors are fake tensors.
         If not, try to convert them to fake tensors.
        Returns the original args, kwargs, and a flattened list of (args, kwargs) that are fake tensors.
         """
-        flat_arg_fake_tensors: List[Any] = []
+        flat_arg_fake_tensors: List[FakeTensor] = []
 
-        def validate(x):
-            if not isinstance(x, torch.Tensor):
+        def validate(x: T) -> Union[T, FakeTensor]:
+            if not isinstance(x, Tensor):
                 return x
 
             nonlocal flat_arg_fake_tensors
@@ -1844,26 +1949,34 @@ class FakeTensorMode(TorchDispatchMode):
                     f"with 'allow_non_fake_inputs'. Found in {render_call(func, args, kwargs)}"
                 )
 
-                x = converter.from_real_tensor(self, x)
+                out = converter.from_real_tensor(self, x)
+            else:
+                out = x
 
-            flat_arg_fake_tensors.append(x)
-            return x
+            flat_arg_fake_tensors.append(out)
+            return out
 
         validated_args = [validate(a) for a in flat_args]
         return validated_args, flat_arg_fake_tensors
 
-    def wrap_meta_outputs_with_default_device_logic(self, r, func, flat_args, device):
+    def wrap_meta_outputs_with_default_device_logic(
+        self,
+        r: object,
+        func: OpOverload,
+        flat_args: Sequence[object],
+        device: torch.device,
+    ) -> PyTree:
         converter = self.fake_tensor_converter
 
         # Lazily initialized, in case there are no tensor returns
         common_device = None
         has_scalar_only_inputs = False
 
-        def wrap(e):
+        def wrap(e: T) -> Union[T, FakeTensor]:
             nonlocal common_device
             nonlocal has_scalar_only_inputs
 
-            if not isinstance(e, torch.Tensor):
+            if not isinstance(e, Tensor):
                 return e
 
             if common_device is None:
@@ -1878,7 +1991,7 @@ class FakeTensorMode(TorchDispatchMode):
                     e.device == common_device,
                     lambda: f"FakeTensor is wrapped to wrong device, found {e.device}, expected {common_device}",
                 )
-                return e
+                return cast(T, e)
             elif converter is not None:
                 if has_scalar_only_inputs:
                     # Under FakeTensorMode, op accepts scalar only inputs, such as aten.add/sub/mul/div,
@@ -1908,14 +2021,14 @@ class FakeTensorMode(TorchDispatchMode):
         aten._sparse_coo_tensor_with_dims_and_tensors.default,
     )
 
-    def cpp_meta_supports_symint(self, func):
+    def cpp_meta_supports_symint(self, func: OpOverload) -> bool:
         if torch.Tag.view_copy in func.tags:
             return True
         return func in self._cpp_meta_supports_symint
 
     lift_fns = ordered_set(aten.lift_fresh.default, aten.lift_fresh_copy.default)
 
-    def may_turn_const(self, t):
+    def may_turn_const(self, t: Tensor) -> bool:
         return (
             t.numel() <= CONSTANT_NUMEL_LIMIT
             and not t.is_sparse
@@ -1924,8 +2037,12 @@ class FakeTensorMode(TorchDispatchMode):
         )
 
     def invalidate_written_to_constants(
-        self, func, flat_arg_fake_tensors, args, kwargs
-    ):
+        self,
+        func: OpOverload,
+        flat_arg_fake_tensors: Sequence[FakeTensor],
+        args: Sequence[object],
+        kwargs: Mapping[str, object],
+    ) -> None:
         any_constant = any(e.constant is not None for e in flat_arg_fake_tensors)
         schema_info = get_schema_info(func)
         if any_constant and schema_info.is_mutable():
@@ -1943,13 +2060,13 @@ class FakeTensorMode(TorchDispatchMode):
     def from_tensor(
         self,
-        tensor,
+        tensor: Tensor,
         *,
-        static_shapes=None,
+        static_shapes: Optional[bool] = None,
         source: Optional[Source] = None,
-        symbolic_context=None,
-        trace=True,
-    ):
+        symbolic_context: Optional[SymbolicContext] = None,
+        trace: bool = True,
+    ) -> FakeTensor:
         shape_env: Optional[ShapeEnv] = self.shape_env
         if static_shapes is None:
             static_shapes = self.static_shapes
@@ -1968,10 +2085,17 @@ class FakeTensorMode(TorchDispatchMode):
         )
 
 
+_StoragePointer = object
+
+
 # NB: returns fake tensors
 def run_fallback_kernel(
-    fake_mode, func, flat_args, args_spec, orig_not_implemented_exception
-):
+    fake_mode: FakeTensorMode,
+    func: OpOverload,
+    flat_args: Sequence[object],
+    args_spec: PyTree,
+    orig_not_implemented_exception: RuntimeError,
+) -> FakeTensor:
     # these should all be supported, just to be safe
     # avoid fallback for operators which inplace modify metadata
     # because the input fake tensors would be umodified
@@ -1984,7 +2108,7 @@ def run_fallback_kernel(
     # REAL compute (not with meta device)
     with no_dispatch():
 
-        def to_real_tensor(e):
+        def to_real_tensor(e: T) -> Union[T, Tensor]:
             if fake_mode.is_our_fake(e):
                 out = torch.zeros_like(e, device=e.fake_device)
                 if e.is_sparse:
@@ -1998,11 +2122,10 @@ def run_fallback_kernel(
 
         r = func(*args, **kwargs)
 
-    tensor_impls = set()
-    storages = set()
+    storages: Set[_StoragePointer] = set()
 
     for e in flat_args:
-        if isinstance(e, torch.Tensor):
+        if isinstance(e, Tensor):
             if not e.is_sparse:
                 storages.add(e._typed_storage()._cdata)
 
@@ -2011,15 +2134,15 @@ def run_fallback_kernel(
     # not be set up, bc of conversion to device, unless we can reuse an
     # input impl
 
-    def map_out(e):
+    def map_out(e: T) -> Union[T, FakeTensor]:
         if id(e) not in inp_impls and (
-            isinstance(e, torch.Tensor)
+            isinstance(e, Tensor)
             and not e.is_sparse
             and e._typed_storage()._cdata in storages
         ):
             raise orig_not_implemented_exception
 
-        if isinstance(e, torch.Tensor):
+        if isinstance(e, Tensor):
             if id(e) in inp_impls:
                 return inp_impls[id(e)]
             else:
@@ -2033,20 +2156,28 @@ def run_fallback_kernel(
 # Just for use to allow copying a module to fake tensors,
 # does not apply elsewhere
 class FakeCopyMode(TorchFunctionMode):
-    def __init__(self, fake_mode):
+    def __init__(self, fake_mode: FakeTensorMode) -> None:
         self.fake_mode = fake_mode
 
-    def __torch_function__(self, func, types, args=(), kwargs=None):
+    def __torch_function__(
+        self,
+        func: OpOverload,
+        types: Sequence[Type],
+        args: Sequence[object] = (),
+        kwargs: Optional[Mapping[str, object]] = None,
+    ) -> FakeTensor:
         kwargs = kwargs if kwargs else {}
 
         # clone will get called in Parameter deepcopy
         if func == torch._C.TensorBase.clone:
+            assert isinstance(args[0], Tensor)
             return func(
                 self.fake_mode.from_tensor(args[0], static_shapes=True), **kwargs
             )
-        elif func == torch.Tensor.__deepcopy__:
+        elif func == Tensor.__deepcopy__:
             assert len(args) == 2 and len(kwargs) == 0
-            tensor, memo = args
+            tensor = cast(Tensor, args[0])
+            memo = cast(Dict[int, FakeTensor], args[1])
 
             if id(tensor) in memo:
                 return memo[id(tensor)]
@@ -2059,7 +2190,7 @@ class FakeCopyMode(TorchFunctionMode):
             return func(*args, **kwargs)
 
 
-def _device_handler(args):
+def _device_handler(args: Sequence[object]) -> torch.device:
    # NB: Don't use is_our_fake, just serve the fake information
    # as is. Notice we don't use 'self'; we use args[0].fake_mode
    # because they may not be the same. It would also be possible
@@ -2083,24 +2214,30 @@ def _device_handler(args):
 # fake tensor is not supported. What we actually wanted to happen
 # was to give the subclass a chance to figure out what it wants to
 # before erroring out. Returning NotImplemented here allows this.
-def _check_for_subclass(flat_args):
+def _check_for_subclass(flat_args: Sequence[object]) -> bool:
     return any(_check_for_subclass_arg(x) for x in flat_args)
 
 
-def _check_for_subclass_arg(x):
+def _check_for_subclass_arg(x: object) -> bool:
     return (
         not isinstance(x, FakeTensor)
-        and isinstance(x, torch.Tensor)
-        and type(x) is not torch.Tensor
+        and isinstance(x, Tensor)
+        and type(x) is not Tensor
         and type(x) is not torch.nn.Parameter
     )
 
 
 _DISPATCH_META_HANDLERS = {
     torch.ops.prim.device.default: _device_handler,
-    torch.ops.aten.size.default: lambda args: tuple(int(s) for s in args[0].size()),
-    torch.ops.aten.stride.default: lambda args: tuple(int(s) for s in args[0].stride()),
-    torch.ops.aten.storage_offset.default: lambda args: int(args[0].storage_offset()),
+    torch.ops.aten.size.default: lambda args: tuple(
+        int(s) for s in cast(Tensor, args[0]).size()
+    ),
+    torch.ops.aten.stride.default: lambda args: tuple(
+        int(s) for s in cast(Tensor, args[0]).stride()
+    ),
+    torch.ops.aten.storage_offset.default: lambda args: int(
+        cast(Tensor, args[0]).storage_offset()
+    ),
 }
 
 _DISPATCH_HANDLE_DIRECTLY = ordered_set(
diff --git a/torch/distributed/_composable/fsdp/_fsdp_init.py b/torch/distributed/_composable/fsdp/_fsdp_init.py
index 141addc6b71..a3a11a1a1c9 100644
--- a/torch/distributed/_composable/fsdp/_fsdp_init.py
+++ b/torch/distributed/_composable/fsdp/_fsdp_init.py
@@ -140,7 +140,8 @@ def _move_states_to_device(
                 raise AssertionError(
                     f"Expects DTensor to be moved to {dtensor_mesh_type} but got {tensor.device}"
                 )
-        if is_traceable_wrapper_subclass(tensor):
+        tensor_ = tensor
+        if is_traceable_wrapper_subclass(tensor_):
             with torch.no_grad():  # avoid autograd increasing C++ refcount by 1
                 tensor_on_device = nn.Parameter(tensor.to(device))
                 torch.utils.swap_tensors(tensor, tensor_on_device)
diff --git a/torch/export/_trace.py b/torch/export/_trace.py
index ce7b5d79027..9cbc8aa8087 100644
--- a/torch/export/_trace.py
+++ b/torch/export/_trace.py
@@ -1700,6 +1700,7 @@ def _export_for_training(
 
     # The unbacked symint symbols are updated in aot_export
     # so we serialize them here instead of inside dynamo.
+    assert fake_mode.shape_env is not None
     gm.meta["inline_constraints"] = {
         k: v
         for k, v in fake_mode.shape_env.var_to_range.items()
@@ -1884,6 +1885,7 @@ def _export(
 
     # The unbacked symint symbols are updated in aot_export
     # so we serialize them here instead of inside dynamo.
+    assert fake_mode.shape_env is not None
     gm.meta["inline_constraints"] = {
         k: v
         for k, v in fake_mode.shape_env.var_to_range.items()
diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py
index 41a64b9b5e7..402d218866a 100644
--- a/torch/fx/experimental/proxy_tensor.py
+++ b/torch/fx/experimental/proxy_tensor.py
@@ -1649,6 +1649,7 @@ class _MakefxTracer:
                 return self.fake_tensor_mode.from_tensor(x, source=source)
             # NB: don't match on bools
             elif type(x) is int and self.tracing_mode == "symbolic":
+                assert self.fake_tensor_mode.shape_env is not None, "shape_env should be set if tracing with 'symbolic'"
                 return self.fake_tensor_mode.shape_env.create_symintnode(
                     self.fake_tensor_mode.shape_env.create_symbol(x, source, positive=None),
                     hint=x,
diff --git a/torch/types.py b/torch/types.py
index 7362feb149e..89ba6df550f 100644
--- a/torch/types.py
+++ b/torch/types.py
@@ -17,6 +17,7 @@ from builtins import (  # noqa: F401
 from typing import Any, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
 
 import torch
+from torch import SymInt
 
 if TYPE_CHECKING:
@@ -40,6 +41,7 @@ _device = torch.device
 _qscheme = torch.qscheme
 _layout = torch.layout
 _size = Union[torch.Size, List[builtins.int], Tuple[builtins.int, ...]]
+_symsize = Union[torch.Size, Sequence[Union[_int, SymInt]]]
 _dispatchkey = Union[builtins.str, torch._C.DispatchKey]
 
 # int or SymInt
diff --git a/torch/utils/_python_dispatch.py b/torch/utils/_python_dispatch.py
index f48c27491a4..3d669eb26b2 100644
--- a/torch/utils/_python_dispatch.py
+++ b/torch/utils/_python_dispatch.py
@@ -3,7 +3,7 @@ import contextlib
 import warnings
 
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Set, Union, Protocol, Sequence, Tuple, overload
+from typing import Any, Dict, List, Optional, Set, Union, Protocol, Tuple, Sequence, overload
 from typing_extensions import TypeGuard
 
 import torch
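
One behavioral fix rides along with the annotation work: in torch/_functorch/config.py, os.environ.get returns the raw string whenever the variable is set, so the old default of True made even FAKE_ALLOW_META=0 truthy. A minimal standalone sketch of the before/after parsing (not part of the patch itself):

```python
import os

os.environ["FAKE_ALLOW_META"] = "0"

# Old behavior: a set variable comes back as the string "0", which is truthy,
# so the flag could never actually be turned off from the environment.
old = os.environ.get("FAKE_ALLOW_META", True)
assert bool(old) is True

# New behavior: compare against "0"; an unset variable defaults to "1" (on).
new = os.environ.get("FAKE_ALLOW_META", "1") != "0"
assert new is False
```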
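Several hunks insert `assert ... is not None` before dereferencing `shape_env`. With `FakeTensorMode.shape_env` now declared `Optional[ShapeEnv]`, mypy rejects attribute access on the optional value, and the assert narrows it at that point. A self-contained sketch of the pattern, with toy stand-ins for the real torch classes:

```python
from typing import Dict, Optional


class ShapeEnv:
    def __init__(self) -> None:
        self.var_to_range: Dict[str, range] = {}


class FakeTensorMode:
    def __init__(self, shape_env: Optional[ShapeEnv] = None) -> None:
        self.shape_env = shape_env


def inline_constraints(mode: FakeTensorMode) -> Dict[str, range]:
    # Without this assert mypy reports:
    #   Item "None" of "Optional[ShapeEnv]" has no attribute "var_to_range"
    assert mode.shape_env is not None
    return dict(mode.shape_env.var_to_range)
```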
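`is_fake` and `is_our_fake` return `TypeGuard[...]` rather than plain `bool`: a True result tells the checker to narrow the argument's type inside the branch, which is why callers can pass `object` and still use tensor attributes afterwards without casts. A hedged sketch of the mechanism with stand-in classes:

```python
from typing import Union

from typing_extensions import TypeGuard


class Tensor:
    pass


class FakeTensor(Tensor):
    fake_device = "meta"


def is_fake(x: object) -> TypeGuard[FakeTensor]:
    # Returning True promises the type checker that x is a FakeTensor.
    return isinstance(x, FakeTensor)


def describe(x: Union[Tensor, int]) -> str:
    if is_fake(x):
        # x is narrowed to FakeTensor here; no cast or type: ignore needed.
        return f"fake on {x.fake_device}"
    return "real tensor or scalar"
```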
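The dispatch entry points trade `kwargs=None` for `kwargs: Mapping[str, object] = immutable_dict()`. A shared mutable `{}` default is the classic Python pitfall; an immutable mapping keeps the shared default safe while matching the declared `Mapping` type. Sketch using the real `torch.fx.immutable_collections.immutable_dict` (the `dispatch` function here is a toy, not the patch's method):

```python
from typing import Callable, Mapping

from torch.fx.immutable_collections import immutable_dict


def dispatch(
    func: Callable[..., object],
    args: tuple = (),
    kwargs: Mapping[str, object] = immutable_dict(),
) -> object:
    # One immutable_dict instance is shared across all calls, but any
    # attempted mutation raises NotImplementedError, so it cannot be
    # corrupted the way a plain `{}` default could.
    return func(*args, **kwargs)


print(dispatch(max, (3, 5)))  # 5
```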
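`ordered_set` keeps its one-liner body but gains the signature `Dict[T, Literal[True]]`, which documents the idiom: dicts preserve insertion order (Python 3.7+), so a dict whose values are all `True` doubles as an ordered set with O(1) membership tests. Sketch:

```python
from typing import Dict, Literal, TypeVar

T = TypeVar("T")


def ordered_set(*items: T) -> Dict[T, Literal[True]]:
    return dict.fromkeys(items, True)


fns = ordered_set("lift_fresh", "lift_fresh_copy")
assert "lift_fresh" in fns  # O(1) membership, like a set
assert list(fns) == ["lift_fresh", "lift_fresh_copy"]  # order preserved
```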