typing fake_tensor.py (#128041)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128041
Approved by: https://github.com/eellison
ghstack dependencies: #129182
This commit is contained in:
Aaron Orenstein 2024-07-12 08:19:16 -07:00 committed by PyTorch MergeBot
parent 1ad0f38a37
commit 567482973d
14 changed files with 397 additions and 223 deletions

View file

@ -1204,7 +1204,7 @@ def gen_pyi(
], ],
"set_": [ "set_": [
"def set_(self, storage: Union[Storage, TypedStorage, UntypedStorage], " "def set_(self, storage: Union[Storage, TypedStorage, UntypedStorage], "
"offset: _int, size: _size, stride: _size) -> Tensor: ...", "offset: _int, size: _symsize, stride: _symsize) -> Tensor: ...",
"def set_(self, storage: Union[Storage, TypedStorage, UntypedStorage]) -> Tensor: ...", "def set_(self, storage: Union[Storage, TypedStorage, UntypedStorage]) -> Tensor: ...",
], ],
"split": [ "split": [

View file

@ -56,6 +56,7 @@ from torch.types import (
_qscheme, _qscheme,
_size, _size,
_str, _str,
_symsize,
) )
from torch.utils._python_dispatch import TorchDispatchMode from torch.utils._python_dispatch import TorchDispatchMode
@ -1661,6 +1662,18 @@ class _SetExcludeDispatchKeyGuard:
def __enter__(self): ... def __enter__(self): ...
def __exit__(self, exc_type, exc_value, traceback): ... def __exit__(self, exc_type, exc_value, traceback): ...
# Defined in torch/csrc/utils/schema_info.h
class _SchemaInfo:
def __init__(self, schema: _int) -> None: ...
@overload
def is_mutable(self) -> _bool: ...
@overload
def is_mutable(self, name: str) -> _bool: ...
def has_argument(self, name: str) -> _bool: ...
# Defined in torch/csrc/utils/init.cpp # Defined in torch/csrc/utils/init.cpp
class BenchmarkConfig: class BenchmarkConfig:
num_calling_threads: _int num_calling_threads: _int

View file

@ -36,6 +36,9 @@ from typing import (
) )
from typing_extensions import ParamSpec as _ParamSpec, TypeGuard as _TypeGuard from typing_extensions import ParamSpec as _ParamSpec, TypeGuard as _TypeGuard
if TYPE_CHECKING:
from .types import IntLikeType
# multipy/deploy is setting this import before importing torch, this is the most # multipy/deploy is setting this import before importing torch, this is the most
# reliable way we have to detect if we're running within deploy. # reliable way we have to detect if we're running within deploy.
@ -471,6 +474,9 @@ class SymInt:
def __add__(self, other) -> "SymInt": def __add__(self, other) -> "SymInt":
raise TypeError("type stub not overridden") raise TypeError("type stub not overridden")
def __mod__(self, other: "IntLikeType") -> "SymInt":
raise TypeError("type stub not overridden")
def __mul__(self, other) -> "SymInt": def __mul__(self, other) -> "SymInt":
raise TypeError("type stub not overridden") raise TypeError("type stub not overridden")
@ -504,6 +510,9 @@ class SymInt:
def __neg__(self): def __neg__(self):
raise TypeError("type stub not overridden") raise TypeError("type stub not overridden")
def __sub__(self, other: "IntLikeType") -> "SymInt":
raise TypeError("type stub not overridden")
def __repr__(self): def __repr__(self):
return self.node._graph_repr() return self.node._graph_repr()

View file

@ -165,6 +165,7 @@ def _apply_func_to_inner_tensors_of_same_dim(func, t, *args, **kwargs):
assert is_traceable_wrapper_subclass(t) assert is_traceable_wrapper_subclass(t)
attrs, ctx = t.__tensor_flatten__() attrs, ctx = t.__tensor_flatten__()
assert isinstance(t, torch.Tensor)
for attr in attrs: for attr in attrs:
inner = getattr(t, attr) inner = getattr(t, attr)
if inner.dim() == t.dim(): if inner.dim() == t.dim():

View file

@ -83,6 +83,7 @@ def fakify(
constraint_sizes=[None] * n_dims, constraint_sizes=[None] * n_dims,
) )
t_id = id(t) t_id = id(t)
assert mode.shape_env is not None
if t_id in t_constraints: if t_id in t_constraints:
for i, constraint in t_constraints[t_id].items(): for i, constraint in t_constraints[t_id].items():
symbolic_context.constraint_sizes[i] = constraint.constraint_range symbolic_context.constraint_sizes[i] = constraint.constraint_range
@ -256,6 +257,7 @@ def produce_guards_and_solve_constraints(
_disable_forced_specializations: if True, avoids forced specializations _disable_forced_specializations: if True, avoids forced specializations
""" """
shape_env = fake_mode.shape_env shape_env = fake_mode.shape_env
assert shape_env is not None
assert shape_env.tracked_fakes is not None assert shape_env.tracked_fakes is not None
placeholders = [tf.fake for tf in shape_env.tracked_fakes] placeholders = [tf.fake for tf in shape_env.tracked_fakes]
@ -322,6 +324,7 @@ def make_constraints(
""" """
shape_env = fake_mode.shape_env shape_env = fake_mode.shape_env
assert shape_env is not None
inline_constraints = gm.meta.get("inline_constraints", []) inline_constraints = gm.meta.get("inline_constraints", [])
range_constraints = { range_constraints = {
symbol: inline_constraints[symbol] for symbol in inline_constraints symbol: inline_constraints[symbol] for symbol in inline_constraints

View file

@ -12,7 +12,7 @@ import pprint
from contextlib import nullcontext from contextlib import nullcontext
from dataclasses import dataclass, field from dataclasses import dataclass, field
from functools import wraps from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Tuple, Union from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
import torch import torch
import torch.utils.dlpack import torch.utils.dlpack
@ -1450,7 +1450,7 @@ Expected metadata: {str(expected_tangent_metadata)}
Runtime metadata: {str(runtime_tangent_metadata)} Runtime metadata: {str(runtime_tangent_metadata)}
shape: {str(x.shape)} shape: {str(cast(torch.Tensor, x).shape)}
To fix this, your tensor subclass must implement the dunder method __force_to_same_metadata__. To fix this, your tensor subclass must implement the dunder method __force_to_same_metadata__.
""" """
) )
@ -1830,14 +1830,16 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
) )
assert CompiledFunction.metadata.traced_tangent_metas is not None assert CompiledFunction.metadata.traced_tangent_metas is not None
all_args = [ all_args = [
AOTDispatchAutograd.coerce_runtime_tangent( (
t, AOTDispatchAutograd.coerce_runtime_tangent(
CompiledFunction.metadata.traced_tangent_metas[ t,
i - tangents_start_idx CompiledFunction.metadata.traced_tangent_metas[
], i - tangents_start_idx
],
)
if tangents_start_idx <= i < tangents_end_idx
else t
) )
if tangents_start_idx <= i < tangents_end_idx
else t
for i, t in enumerate(all_args) for i, t in enumerate(all_args)
] ]
all_args = unwrap_tensor_subclasses( all_args = unwrap_tensor_subclasses(
@ -1849,9 +1851,11 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
# Make the tangents contiguous. Note that we must do this after subclass desugaring # Make the tangents contiguous. Note that we must do this after subclass desugaring
# because inputs to inductor have to be contiguous # because inputs to inductor have to be contiguous
all_args = [ all_args = [
AOTDispatchAutograd._force_contiguous(t) (
if (tangents_start_idx <= i < tangents_end_idx) AOTDispatchAutograd._force_contiguous(t)
else t if (tangents_start_idx <= i < tangents_end_idx)
else t
)
for i, t in enumerate(all_args) for i, t in enumerate(all_args)
] ]

View file

@ -5,6 +5,7 @@ AOTAutograd's responsibility is to trace through all pytorch capabilities that l
and this includes tensor subclasses that implement __torch_dispatch__. and this includes tensor subclasses that implement __torch_dispatch__.
""" """
import typing
from typing import Any, List, Optional, Tuple, Union from typing import Any, List, Optional, Tuple, Union
import torch.utils._pytree as pytree import torch.utils._pytree as pytree
@ -115,7 +116,7 @@ def unwrap_tensor_subclasses(wrapped_args, *, is_joint_structure: bool):
xs_inner = [] xs_inner = []
for x in xs: for x in xs:
if is_traceable_wrapper_subclass(x): if is_traceable_wrapper_subclass(x):
xs_inner.extend(get_plain_tensors(x)) xs_inner.extend(get_plain_tensors(typing.cast(Tensor, x)))
else: else:
xs_inner.append(x) xs_inner.append(x)
return xs_inner return xs_inner

View file

@ -16,7 +16,7 @@ from typing import TYPE_CHECKING
functionalize_rng_ops = False functionalize_rng_ops = False
# can be useful for debugging if we are incorrectly creating meta fake tensors # can be useful for debugging if we are incorrectly creating meta fake tensors
fake_tensor_allow_meta = os.environ.get("FAKE_ALLOW_META", True) fake_tensor_allow_meta = os.environ.get("FAKE_ALLOW_META", "1") != "0"
# Enables optional asserts in hotpath code to check for errors. If # Enables optional asserts in hotpath code to check for errors. If
# you are seeing weird accuracy problems, try turning this on. # you are seeing weird accuracy problems, try turning this on.
@ -24,7 +24,7 @@ fake_tensor_allow_meta = os.environ.get("FAKE_ALLOW_META", True)
# but it is on by default for aot_eager. # but it is on by default for aot_eager.
debug_assert = False debug_assert = False
debug_partitioner = os.environ.get("AOT_PARTITIONER_DEBUG", False) debug_partitioner = os.environ.get("AOT_PARTITIONER_DEBUG", "0") != "0"
# Today, if you are in a situation where there is "false aliasing" # Today, if you are in a situation where there is "false aliasing"
# (e.g. you have a bunch of model parameters that all alias the same underlying buffer), # (e.g. you have a bunch of model parameters that all alias the same underlying buffer),

File diff suppressed because it is too large Load diff

View file

@ -140,7 +140,8 @@ def _move_states_to_device(
raise AssertionError( raise AssertionError(
f"Expects DTensor to be moved to {dtensor_mesh_type} but got {tensor.device}" f"Expects DTensor to be moved to {dtensor_mesh_type} but got {tensor.device}"
) )
if is_traceable_wrapper_subclass(tensor): tensor_ = tensor
if is_traceable_wrapper_subclass(tensor_):
with torch.no_grad(): # avoid autograd increasing C++ refcount by 1 with torch.no_grad(): # avoid autograd increasing C++ refcount by 1
tensor_on_device = nn.Parameter(tensor.to(device)) tensor_on_device = nn.Parameter(tensor.to(device))
torch.utils.swap_tensors(tensor, tensor_on_device) torch.utils.swap_tensors(tensor, tensor_on_device)

View file

@ -1700,6 +1700,7 @@ def _export_for_training(
# The unbacked symint symbols are updated in aot_export # The unbacked symint symbols are updated in aot_export
# so we serialize them here instead of inside dynamo. # so we serialize them here instead of inside dynamo.
assert fake_mode.shape_env is not None
gm.meta["inline_constraints"] = { gm.meta["inline_constraints"] = {
k: v k: v
for k, v in fake_mode.shape_env.var_to_range.items() for k, v in fake_mode.shape_env.var_to_range.items()
@ -1884,6 +1885,7 @@ def _export(
# The unbacked symint symbols are updated in aot_export # The unbacked symint symbols are updated in aot_export
# so we serialize them here instead of inside dynamo. # so we serialize them here instead of inside dynamo.
assert fake_mode.shape_env is not None
gm.meta["inline_constraints"] = { gm.meta["inline_constraints"] = {
k: v k: v
for k, v in fake_mode.shape_env.var_to_range.items() for k, v in fake_mode.shape_env.var_to_range.items()

View file

@ -1649,6 +1649,7 @@ class _MakefxTracer:
return self.fake_tensor_mode.from_tensor(x, source=source) return self.fake_tensor_mode.from_tensor(x, source=source)
# NB: don't match on bools # NB: don't match on bools
elif type(x) is int and self.tracing_mode == "symbolic": elif type(x) is int and self.tracing_mode == "symbolic":
assert self.fake_tensor_mode.shape_env is not None, "shape_env should be set if tracing with 'symbolic'"
return self.fake_tensor_mode.shape_env.create_symintnode( return self.fake_tensor_mode.shape_env.create_symintnode(
self.fake_tensor_mode.shape_env.create_symbol(x, source, positive=None), self.fake_tensor_mode.shape_env.create_symbol(x, source, positive=None),
hint=x, hint=x,

View file

@ -17,6 +17,7 @@ from builtins import ( # noqa: F401
from typing import Any, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union from typing import Any, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
import torch import torch
from torch import SymInt
if TYPE_CHECKING: if TYPE_CHECKING:
@ -40,6 +41,7 @@ _device = torch.device
_qscheme = torch.qscheme _qscheme = torch.qscheme
_layout = torch.layout _layout = torch.layout
_size = Union[torch.Size, List[builtins.int], Tuple[builtins.int, ...]] _size = Union[torch.Size, List[builtins.int], Tuple[builtins.int, ...]]
_symsize = Union[torch.Size, Sequence[Union[_int, SymInt]]]
_dispatchkey = Union[builtins.str, torch._C.DispatchKey] _dispatchkey = Union[builtins.str, torch._C.DispatchKey]
# int or SymInt # int or SymInt

View file

@ -3,7 +3,7 @@ import contextlib
import warnings import warnings
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set, Union, Protocol, Sequence, Tuple, overload from typing import Any, Dict, List, Optional, Set, Union, Protocol, Tuple, Sequence, overload
from typing_extensions import TypeGuard from typing_extensions import TypeGuard
import torch import torch