typing fake_tensor.py (#128041)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128041
Approved by: https://github.com/eellison
ghstack dependencies: #129182

Parent: 1ad0f38a37
Commit: 567482973d
14 changed files with 397 additions and 223 deletions
@@ -1204,7 +1204,7 @@ def gen_pyi(
             ],
             "set_": [
                 "def set_(self, storage: Union[Storage, TypedStorage, UntypedStorage], "
-                "offset: _int, size: _size, stride: _size) -> Tensor: ...",
+                "offset: _int, size: _symsize, stride: _symsize) -> Tensor: ...",
                 "def set_(self, storage: Union[Storage, TypedStorage, UntypedStorage]) -> Tensor: ...",
             ],
             "split": [
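The two string literals above are adjacent, so Python concatenates them into a single stub line; the only change is widening `size` and `stride` from `_size` to `_symsize` so the generated `Tensor.set_` annotation accepts `SymInt` dimensions. A minimal sketch of that concatenation, using the literals from this hunk:

    # Adjacent string literals merge into one signature string at compile time.
    sig = (
        "def set_(self, storage: Union[Storage, TypedStorage, UntypedStorage], "
        "offset: _int, size: _symsize, stride: _symsize) -> Tensor: ..."
    )
    assert "size: _symsize" in sig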
@@ -56,6 +56,7 @@ from torch.types import (
     _qscheme,
     _size,
     _str,
+    _symsize,
 )
 from torch.utils._python_dispatch import TorchDispatchMode

@@ -1661,6 +1662,18 @@ class _SetExcludeDispatchKeyGuard:
     def __enter__(self): ...
     def __exit__(self, exc_type, exc_value, traceback): ...

+# Defined in torch/csrc/utils/schema_info.h
+
+class _SchemaInfo:
+    def __init__(self, schema: _int) -> None: ...
+
+    @overload
+    def is_mutable(self) -> _bool: ...
+    @overload
+    def is_mutable(self, name: str) -> _bool: ...
+
+    def has_argument(self, name: str) -> _bool: ...
+
 # Defined in torch/csrc/utils/init.cpp
 class BenchmarkConfig:
     num_calling_threads: _int
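The new `_SchemaInfo` stub uses `@overload` to describe a method that accepts either no argument or an argument name. A sketch of how one runtime body can back both declared call shapes (`SchemaInfoLike` and its fields are hypothetical; the real `_SchemaInfo` is a C++ binding):

    from typing import Optional, Set, overload

    class SchemaInfoLike:
        """Hypothetical stand-in for the _SchemaInfo binding."""

        def __init__(self, mutable_args: Set[str]) -> None:
            self._mutable_args = mutable_args

        @overload
        def is_mutable(self) -> bool: ...
        @overload
        def is_mutable(self, name: str) -> bool: ...

        def is_mutable(self, name: Optional[str] = None) -> bool:
            # One runtime implementation serves both overload signatures.
            if name is None:
                return bool(self._mutable_args)
            return name in self._mutable_args

    info = SchemaInfoLike({"self"})
    assert info.is_mutable() and info.is_mutable("self")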
@@ -36,6 +36,9 @@ from typing import (
 )
 from typing_extensions import ParamSpec as _ParamSpec, TypeGuard as _TypeGuard

+if TYPE_CHECKING:
+    from .types import IntLikeType
+

 # multipy/deploy is setting this import before importing torch, this is the most
 # reliable way we have to detect if we're running within deploy.
@@ -471,6 +474,9 @@ class SymInt:
     def __add__(self, other) -> "SymInt":
         raise TypeError("type stub not overridden")

+    def __mod__(self, other: "IntLikeType") -> "SymInt":
+        raise TypeError("type stub not overridden")
+
     def __mul__(self, other) -> "SymInt":
         raise TypeError("type stub not overridden")

@@ -504,6 +510,9 @@ class SymInt:
     def __neg__(self):
         raise TypeError("type stub not overridden")

+    def __sub__(self, other: "IntLikeType") -> "SymInt":
+        raise TypeError("type stub not overridden")
+
     def __repr__(self):
         return self.node._graph_repr()

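Both hunks follow the existing pattern in this class: the bodies exist only to give type checkers a precise signature (note the `IntLikeType` operands), and they raise because the real operators are installed on `SymInt` at import time (by `torch.fx.experimental.sym_node` in current PyTorch). A sketch of that stub-then-patch pattern, with hypothetical names:

    class Sym:
        """Hypothetical class illustrating the stub-then-patch pattern."""

        def __sub__(self, other: int) -> "Sym":
            # Placeholder body: gives checkers a signature, never meant to run.
            raise TypeError("type stub not overridden")

    def _real_sub(self: Sym, other: int) -> Sym:
        return self  # stand-in arithmetic for the sketch

    Sym.__sub__ = _real_sub  # the real implementation replaces the stub
    assert isinstance(Sym() - 1, Sym)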
@@ -165,6 +165,7 @@ def _apply_func_to_inner_tensors_of_same_dim(func, t, *args, **kwargs):
     assert is_traceable_wrapper_subclass(t)

     attrs, ctx = t.__tensor_flatten__()
+    assert isinstance(t, torch.Tensor)
     for attr in attrs:
         inner = getattr(t, attr)
         if inner.dim() == t.dim():
@@ -83,6 +83,7 @@ def fakify(
         constraint_sizes=[None] * n_dims,
     )
     t_id = id(t)
+    assert mode.shape_env is not None
     if t_id in t_constraints:
         for i, constraint in t_constraints[t_id].items():
             symbolic_context.constraint_sizes[i] = constraint.constraint_range
@@ -256,6 +257,7 @@ def produce_guards_and_solve_constraints(
         _disable_forced_specializations: if True, avoids forced specializations
     """
     shape_env = fake_mode.shape_env
+    assert shape_env is not None
     assert shape_env.tracked_fakes is not None

     placeholders = [tf.fake for tf in shape_env.tracked_fakes]
@@ -322,6 +324,7 @@ def make_constraints(
     """

     shape_env = fake_mode.shape_env
+    assert shape_env is not None
     inline_constraints = gm.meta.get("inline_constraints", [])
     range_constraints = {
         symbol: inline_constraints[symbol] for symbol in inline_constraints
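This hunk and several others in the commit (`fakify` above, and the `_export`/`_MakefxTracer` hunks below) add the same kind of `assert ... is not None` before touching `shape_env`. With `shape_env` annotated as `Optional` in the new typing, the assert is what lets mypy narrow it for the rest of the scope, while also documenting a real runtime invariant. A minimal sketch of the narrowing, using a hypothetical stand-in class:

    from typing import List, Optional

    class ShapeEnvLike:
        """Hypothetical stand-in for ShapeEnv."""
        tracked_fakes: Optional[List[object]] = None

    def count_fakes(shape_env: Optional[ShapeEnvLike]) -> int:
        # Without these asserts mypy reports: Item "None" of "Optional[...]"
        # has no attribute "tracked_fakes". Each assert narrows the type.
        assert shape_env is not None
        assert shape_env.tracked_fakes is not None
        return len(shape_env.tracked_fakes)

    env = ShapeEnvLike()
    env.tracked_fakes = []
    assert count_fakes(env) == 0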
@@ -12,7 +12,7 @@ import pprint
 from contextlib import nullcontext
 from dataclasses import dataclass, field
 from functools import wraps
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union

 import torch
 import torch.utils.dlpack
@@ -1450,7 +1450,7 @@ Expected metadata: {str(expected_tangent_metadata)}

 Runtime metadata: {str(runtime_tangent_metadata)}

-shape: {str(x.shape)}
+shape: {str(cast(torch.Tensor, x).shape)}
 To fix this, your tensor subclass must implement the dunder method __force_to_same_metadata__.
 """
             )
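`typing.cast` performs no conversion or checking at runtime; it returns its second argument unchanged while telling the checker to treat it as the named type, so this error-message f-string behaves exactly as before. A quick demonstration:

    from typing import Union, cast

    import torch

    x: Union[torch.Tensor, object] = torch.zeros(2, 3)
    # cast() is the identity at runtime; it only affects static analysis,
    # letting .shape resolve on a value the checker sees as a union.
    assert cast(torch.Tensor, x) is x
    print(f"shape: {cast(torch.Tensor, x).shape}")  # shape: torch.Size([2, 3])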
@@ -1830,14 +1830,16 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
             )
             assert CompiledFunction.metadata.traced_tangent_metas is not None
             all_args = [
-                AOTDispatchAutograd.coerce_runtime_tangent(
-                    t,
-                    CompiledFunction.metadata.traced_tangent_metas[
-                        i - tangents_start_idx
-                    ],
+                (
+                    AOTDispatchAutograd.coerce_runtime_tangent(
+                        t,
+                        CompiledFunction.metadata.traced_tangent_metas[
+                            i - tangents_start_idx
+                        ],
+                    )
+                    if tangents_start_idx <= i < tangents_end_idx
+                    else t
                 )
-                if tangents_start_idx <= i < tangents_end_idx
-                else t
                 for i, t in enumerate(all_args)
             ]
             all_args = unwrap_tensor_subclasses(
@@ -1849,9 +1851,11 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
             # Make the tangents contiguous. Note that we must do this after subclass desugaring
             # because inputs to inductor have to be contiguous
             all_args = [
-                AOTDispatchAutograd._force_contiguous(t)
-                if (tangents_start_idx <= i < tangents_end_idx)
-                else t
+                (
+                    AOTDispatchAutograd._force_contiguous(t)
+                    if (tangents_start_idx <= i < tangents_end_idx)
+                    else t
+                )
                 for i, t in enumerate(all_args)
             ]

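Both comprehension rewrites are behavior-preserving: the added parentheses only group the conditional expression explicitly, the style auto-formatters such as black produce once a line is touched. A check that the two spellings agree:

    xs = [-2, -1, 0, 1, 2]
    start, end = 1, 4

    a = [abs(x) if start <= i < end else x for i, x in enumerate(xs)]
    b = [(abs(x) if start <= i < end else x) for i, x in enumerate(xs)]
    # Same elements either way; only the grouping is written out explicitly.
    assert a == b == [-2, 1, 0, 1, 2]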
@@ -5,6 +5,7 @@ AOTAutograd's responsibility is to trace through all pytorch capabilities that l
 and this includes tensor subclasses that implement __torch_dispatch__.
 """

+import typing
 from typing import Any, List, Optional, Tuple, Union

 import torch.utils._pytree as pytree
@@ -115,7 +116,7 @@ def unwrap_tensor_subclasses(wrapped_args, *, is_joint_structure: bool):
         xs_inner = []
         for x in xs:
             if is_traceable_wrapper_subclass(x):
-                xs_inner.extend(get_plain_tensors(x))
+                xs_inner.extend(get_plain_tensors(typing.cast(Tensor, x)))
             else:
                 xs_inner.append(x)
         return xs_inner
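The guard `is_traceable_wrapper_subclass` evidently does not narrow `x` to the `Tensor` type that `get_plain_tensors` is annotated to take, hence the explicit `typing.cast`. A sketch of a boolean guard whose result still needs a cast before a `Tensor`-only call (the guard and helper below are hypothetical):

    import typing

    import torch

    def looks_like_wrapper_subclass(t: object) -> bool:
        # Hypothetical guard: returns plain bool, so it performs no static
        # narrowing, unlike an isinstance() check or a TypeGuard.
        return isinstance(t, torch.Tensor) and type(t) is not torch.Tensor

    def plain_shapes(xs: list) -> list:
        out = []
        for x in xs:
            if looks_like_wrapper_subclass(x):
                # The checker still sees x as object; cast before Tensor APIs.
                out.append(typing.cast(torch.Tensor, x).shape)
            else:
                out.append(x)
        return out

    print(plain_shapes([torch.nn.Parameter(torch.ones(2)), 3]))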
@@ -16,7 +16,7 @@ from typing import TYPE_CHECKING
 functionalize_rng_ops = False

 # can be useful for debugging if we are incorrectly creating meta fake tensors
-fake_tensor_allow_meta = os.environ.get("FAKE_ALLOW_META", True)
+fake_tensor_allow_meta = os.environ.get("FAKE_ALLOW_META", "1") != "0"

 # Enables optional asserts in hotpath code to check for errors. If
 # you are seeing weird accuracy problems, try turning this on.
@@ -24,7 +24,7 @@ fake_tensor_allow_meta = os.environ.get("FAKE_ALLOW_META", True)
 # but it is on by default for aot_eager.
 debug_assert = False

-debug_partitioner = os.environ.get("AOT_PARTITIONER_DEBUG", False)
+debug_partitioner = os.environ.get("AOT_PARTITIONER_DEBUG", "0") != "0"

 # Today, if you are in a situation where there is "false aliasing"
 # (e.g. you have a bunch of model parameters that all alias the same underlying buffer),
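This pair of changes fixes a real bug, not just an annotation: `os.environ.get` returns a string whenever the variable is set, and every non-empty string, including "0", is truthy, so setting the variable to 0 previously still enabled the flag. Comparing against "0" with a string default gives both flags a consistent `bool` type and honors 0/1 values:

    import os

    os.environ["FAKE_ALLOW_META"] = "0"  # user asks to disable the flag

    old = os.environ.get("FAKE_ALLOW_META", True)   # -> the string "0"
    new = os.environ.get("FAKE_ALLOW_META", "1") != "0"

    assert bool(old) is True   # old pattern: "0" is truthy, flag stays on
    assert new is False        # new pattern: respects the requested 0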
(File diff suppressed because it is too large.)
@@ -140,7 +140,8 @@ def _move_states_to_device(
                 raise AssertionError(
                     f"Expects DTensor to be moved to {dtensor_mesh_type} but got {tensor.device}"
                 )
-        if is_traceable_wrapper_subclass(tensor):
+        tensor_ = tensor
+        if is_traceable_wrapper_subclass(tensor_):
             with torch.no_grad():  # avoid autograd increasing C++ refcount by 1
                 tensor_on_device = nn.Parameter(tensor.to(device))
                 torch.utils.swap_tensors(tensor, tensor_on_device)
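Rebinding `tensor` to a fresh local `tensor_` before the guard is a common mypy idiom: a new name gets its own narrowed type without fighting the declared type of the original variable. A generic sketch of the idiom:

    from typing import Union

    def render(value: Union[int, str]) -> str:
        value_ = value  # fresh binding; the checker narrows value_ independently
        if isinstance(value_, str):
            return value_.upper()
        return str(value_)

    assert render("ok") == "OK" and render(3) == "3"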
@@ -1700,6 +1700,7 @@ def _export_for_training(

     # The unbacked symint symbols are updated in aot_export
     # so we serialize them here instead of inside dynamo.
+    assert fake_mode.shape_env is not None
     gm.meta["inline_constraints"] = {
         k: v
         for k, v in fake_mode.shape_env.var_to_range.items()
@@ -1884,6 +1885,7 @@ def _export(

     # The unbacked symint symbols are updated in aot_export
     # so we serialize them here instead of inside dynamo.
+    assert fake_mode.shape_env is not None
     gm.meta["inline_constraints"] = {
         k: v
         for k, v in fake_mode.shape_env.var_to_range.items()
@@ -1649,6 +1649,7 @@ class _MakefxTracer:
             return self.fake_tensor_mode.from_tensor(x, source=source)
         # NB: don't match on bools
         elif type(x) is int and self.tracing_mode == "symbolic":
+            assert self.fake_tensor_mode.shape_env is not None, "shape_env should be set if tracing with 'symbolic'"
             return self.fake_tensor_mode.shape_env.create_symintnode(
                 self.fake_tensor_mode.shape_env.create_symbol(x, source, positive=None),
                 hint=x,
@@ -17,6 +17,7 @@ from builtins import ( # noqa: F401
 from typing import Any, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union

 import torch
+from torch import SymInt


 if TYPE_CHECKING:
@@ -40,6 +41,7 @@ _device = torch.device
 _qscheme = torch.qscheme
 _layout = torch.layout
 _size = Union[torch.Size, List[builtins.int], Tuple[builtins.int, ...]]
+_symsize = Union[torch.Size, Sequence[Union[_int, SymInt]]]
 _dispatchkey = Union[builtins.str, torch._C.DispatchKey]

 # int or SymInt
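The new `_symsize` alias is what the regenerated `set_` stub at the top of this commit refers to: unlike `_size`, it accepts any sequence whose elements may be concrete ints or `SymInt`s. A small sketch exercising the alias as defined here (`ndims` is a hypothetical helper):

    from typing import List, Sequence, Tuple, Union

    import torch
    from torch import SymInt

    _int = int
    _size = Union[torch.Size, List[_int], Tuple[_int, ...]]
    _symsize = Union[torch.Size, Sequence[Union[_int, SymInt]]]

    def ndims(size: _symsize) -> int:
        # Lists, tuples, and torch.Size all satisfy Sequence; elements may be
        # plain ints or SymInts produced during symbolic tracing.
        return len(size)

    assert ndims([2, 3]) == 2 and ndims((4,)) == 1 and ndims(torch.Size([5, 6])) == 2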
@@ -3,7 +3,7 @@ import contextlib

 import warnings
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Set, Union, Protocol, Sequence, Tuple, overload
+from typing import Any, Dict, List, Optional, Set, Union, Protocol, Tuple, Sequence, overload
 from typing_extensions import TypeGuard

 import torch