mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Replace TensorMeta with FakeTensor
Signed-off-by: Edward Z. Yang <ezyangfb.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/78836 Approved by: https://github.com/albanD, https://github.com/mruberry
This commit is contained in:
parent
484282a6fd
commit
587efdb5fa
11 changed files with 160 additions and 336 deletions
|
|
@ -14,6 +14,7 @@ from torch.testing._internal.common_dtype import (
|
|||
floating_and_complex_types_and,
|
||||
all_types_and_complex_and,
|
||||
)
|
||||
from torch._subclasses.fake_tensor import FakeTensor
|
||||
from torch.testing._internal.common_utils import (
|
||||
TestCase,
|
||||
is_iterable_of_tensors,
|
||||
|
|
@ -356,7 +357,7 @@ class TestCommon(TestCase):
|
|||
|
||||
def _to_tensormeta(x):
|
||||
if isinstance(x, torch.Tensor):
|
||||
return prims.utils.TensorMeta(x)
|
||||
return FakeTensor.from_tensor(x)
|
||||
return x
|
||||
|
||||
# TODO: iterate over requires_grad true/false
|
||||
|
|
@ -506,7 +507,7 @@ class TestCommon(TestCase):
|
|||
def test_python_ref_errors(self, device, op):
|
||||
def _to_tensormeta(x):
|
||||
if isinstance(x, torch.Tensor):
|
||||
return prims.utils.TensorMeta(x)
|
||||
return FakeTensor.from_tensor(x)
|
||||
return x
|
||||
|
||||
error_inputs = op.error_inputs(device)
|
||||
|
|
|
|||
|
|
@ -840,6 +840,8 @@ class Generator(object):
|
|||
def _dispatch_library(kind: str, name: str, dispatch: str, file: str = "", linenum: Any = 0) -> Any: ...
|
||||
def _dispatch_has_kernel_for_dispatch_key(name: str, dispatch: str) -> _bool: ...
|
||||
def _dispatch_has_kernel(name: str) -> _bool: ...
|
||||
def _dispatch_tls_is_dispatch_key_excluded(dispatch: str) -> _bool: ...
|
||||
def _dispatch_tls_set_dispatch_key_excluded(dispatch: str, val: _bool) -> None: ...
|
||||
|
||||
# Defined in torch/csrc/utils/init.cpp
|
||||
class BenchmarkConfig(object):
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ import torch._prims.utils as utils
|
|||
from torch._prims.utils import (
|
||||
TensorLike,
|
||||
TensorLikeType,
|
||||
TensorMeta,
|
||||
ShapeType,
|
||||
getnvFuserDtype,
|
||||
DimsType,
|
||||
|
|
@ -13,11 +12,14 @@ from torch._prims.utils import (
|
|||
StrideType,
|
||||
Number,
|
||||
NumberType,
|
||||
TensorMeta,
|
||||
)
|
||||
from torch.overrides import has_torch_function, handle_torch_function
|
||||
import torch.library
|
||||
from torch.utils._pytree import tree_map
|
||||
from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
|
||||
from torch._subclasses.fake_tensor import FakeTensor
|
||||
|
||||
import contextlib
|
||||
from typing import Sequence, Optional, Union, Callable, List, Tuple, Any, Type
|
||||
from functools import reduce, partial
|
||||
from enum import Enum
|
||||
|
|
@ -26,6 +28,7 @@ import math
|
|||
|
||||
prim = torch.library.Library("prims", "DEF")
|
||||
prim_impl = torch.library.Library("prims", "IMPL", "CompositeExplicitAutograd")
|
||||
prim_autograd_impl = torch.library.Library("prims", "IMPL", "Autograd")
|
||||
prim_meta_impl = torch.library.Library("prims", "IMPL", "Meta")
|
||||
|
||||
# Experimental module containing prototype "primitive" operations.
|
||||
|
|
@ -286,27 +289,29 @@ class RETURN_TYPE(Enum):
|
|||
def _wrap_tensor_meta(f):
|
||||
def wrap(t):
|
||||
if isinstance(t, torch.Tensor):
|
||||
return TensorMeta(t)
|
||||
else:
|
||||
return t
|
||||
|
||||
def unwrap(t):
|
||||
# TODO: doesn't setup aliasing relation on views correctly
|
||||
if isinstance(t, TensorMeta):
|
||||
return torch.empty_strided(
|
||||
t.shape, t.stride(), dtype=t.dtype, device="meta"
|
||||
)
|
||||
return FakeTensor.from_tensor(t)
|
||||
else:
|
||||
return t
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
wrapped_args = tree_map(wrap, args)
|
||||
wrapped_kwargs = tree_map(wrap, kwargs)
|
||||
return tree_map(unwrap, f(*wrapped_args, **wrapped_kwargs))
|
||||
return f(*wrapped_args, **wrapped_kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _DispatchBelowAutograd():
|
||||
# TODO: AutogradOther
|
||||
old = torch._C._dispatch_tls_is_dispatch_key_excluded("AutogradFunctionality")
|
||||
torch._C._dispatch_tls_set_dispatch_key_excluded("AutogradFunctionality", True)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
torch._C._dispatch_tls_set_dispatch_key_excluded("AutogradFunctionality", old)
|
||||
|
||||
|
||||
def _make_prim(
|
||||
*,
|
||||
schema: str,
|
||||
|
|
@ -330,16 +335,33 @@ def _make_prim(
|
|||
meta(*args, **kwargs)
|
||||
return impl_aten(*args, **kwargs)
|
||||
|
||||
class BackwardsNotSupported(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, args_spec, *flat_args):
|
||||
args, kwargs = tree_unflatten(flat_args, args_spec) # type: ignore[arg-type]
|
||||
with _DispatchBelowAutograd():
|
||||
return _prim(*args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *args):
|
||||
raise RuntimeError("backwards not supported on prim")
|
||||
|
||||
def _autograd_impl(*args, **kwargs):
|
||||
flat_args, args_spec = tree_flatten((args, kwargs))
|
||||
return BackwardsNotSupported.apply(args_spec, *flat_args)
|
||||
|
||||
name = schema.split("(")[0]
|
||||
prim_impl.impl(name, _prim_impl)
|
||||
prim_autograd_impl.impl(name, _autograd_impl)
|
||||
prim_meta_impl.impl(name, _wrap_tensor_meta(meta))
|
||||
|
||||
_prim = getattr(torch.ops.prims, name).default
|
||||
_prim_packet = getattr(torch.ops.prims, name)
|
||||
_prim = _prim_packet.default
|
||||
|
||||
_prim.__doc__ = doc
|
||||
_prim.meta = meta # type: ignore[attr-defined]
|
||||
_prim.impl_nvfuser = impl_nvfuser # type: ignore[attr-defined]
|
||||
_prim.return_type = return_type # type: ignore[attr-defined]
|
||||
for p in (_prim_packet, _prim):
|
||||
p.__doc__ = doc
|
||||
p.impl_nvfuser = impl_nvfuser # type: ignore[attr-defined]
|
||||
p.return_type = return_type # type: ignore[attr-defined]
|
||||
|
||||
return _prim
|
||||
|
||||
|
|
@ -355,7 +377,7 @@ def _elementwise_meta(
|
|||
*args,
|
||||
type_promotion: ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND,
|
||||
args_with_fixed_dtypes: Tuple[TensorLikeType, ...] = None,
|
||||
) -> TensorMeta:
|
||||
) -> FakeTensor:
|
||||
"""
|
||||
Meta function for elementwise operations that produce outputs in the same dtype
|
||||
as their inputs.
|
||||
|
|
@ -1926,7 +1948,7 @@ device_put = _make_prim(
|
|||
|
||||
# NOTE: need to model meta scalars
|
||||
# See https://github.com/pytorch/pytorch/issues/78070
|
||||
def _item_meta(a: TensorLikeType) -> TensorMeta:
|
||||
def _item_meta(a: TensorLikeType) -> FakeTensor:
|
||||
number_type = utils.dtype_to_type(a.dtype)
|
||||
return TensorMeta(number_type(-1))
|
||||
|
||||
|
|
@ -1948,7 +1970,7 @@ item = _make_prim(
|
|||
|
||||
# NOTE: need to model meta scalars
|
||||
# See https://github.com/pytorch/pytorch/issues/78070
|
||||
def _maximum_value_meta(dtype: torch.dtype) -> TensorMeta:
|
||||
def _maximum_value_meta(dtype: torch.dtype) -> FakeTensor:
|
||||
number_type = utils.dtype_to_type(dtype)
|
||||
return TensorMeta(number_type(-1))
|
||||
|
||||
|
|
@ -1980,7 +2002,7 @@ maximum_value = _make_prim(
|
|||
|
||||
# NOTE: need to model meta scalars
|
||||
# See https://github.com/pytorch/pytorch/issues/78070
|
||||
def _minimum_value_meta(dtype: torch.dtype) -> TensorMeta:
|
||||
def _minimum_value_meta(dtype: torch.dtype) -> FakeTensor:
|
||||
number_type = utils.dtype_to_type(dtype)
|
||||
return TensorMeta(number_type(-1))
|
||||
|
||||
|
|
|
|||
|
|
@ -1,14 +1,11 @@
|
|||
import string
|
||||
from typing import Callable, Sequence, Any, Dict
|
||||
from itertools import chain
|
||||
import functools
|
||||
|
||||
|
||||
import torch
|
||||
from torch.fx.graph import Graph, Node
|
||||
import torch.overrides
|
||||
|
||||
from torch._prims.utils import TensorMeta, torch_function_passthrough
|
||||
from torch._prims.utils import torch_function_passthrough
|
||||
import torch._refs as refs
|
||||
|
||||
import torch._refs
|
||||
|
|
@ -32,143 +29,6 @@ _torch_to_reference_map = {
|
|||
}
|
||||
|
||||
|
||||
class PrimContext(torch.overrides.TorchFunctionMode):
|
||||
"""
|
||||
The prototype prim tracing context.
|
||||
|
||||
Example usage:
|
||||
|
||||
import torch._prims.utils as utils
|
||||
from torch._prims.context import PrimContext
|
||||
from torch._prims.executor import execute
|
||||
from torch.overrides import push_torch_function_mode
|
||||
|
||||
a = torch.randn((2, 2))
|
||||
b = torch.randn((2, 2))
|
||||
|
||||
with push_torch_function_mode(PrimContext):
|
||||
meta_a = ctx.placeholder(utils.TensorMeta(a))
|
||||
meta_b = ctx.placeholder(utils.TensorMeta(b))
|
||||
result = torch.add(meta_a, meta_b)
|
||||
ctx.output(result)
|
||||
|
||||
exc_result = execute(ctx, a, b)
|
||||
|
||||
Currently this only acquires a trace of prims, and
|
||||
it does not account for control flow. As such,
|
||||
execute must be called with tensors that have the
|
||||
same metadata (dtype, device, shape...) as
|
||||
the tensors used to trace the operations.
|
||||
|
||||
The tracing context's FX graph can be acquired
|
||||
using its graph attribute.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.graph = Graph()
|
||||
|
||||
# Private attributes for generating names
|
||||
self._tensor_name_counter = 0
|
||||
self._dim_name_counter = 0
|
||||
self._shape_name_counter = 0
|
||||
self._lowercase = tuple(string.ascii_lowercase)
|
||||
self._uppercase = tuple(string.ascii_uppercase)
|
||||
|
||||
@staticmethod
|
||||
def _create_name(idx, chars):
|
||||
name = ""
|
||||
while idx >= len(chars):
|
||||
name = chars[idx % len(chars)] + name
|
||||
idx = idx - len(chars)
|
||||
name = chars[idx] + name
|
||||
|
||||
return name
|
||||
|
||||
def _tensor_name(self):
|
||||
idx = self._tensor_name_counter
|
||||
self._tensor_name_counter = self._tensor_name_counter + 1
|
||||
|
||||
return self._create_name(idx, self._lowercase)
|
||||
|
||||
def _add_user(self, tm: TensorMeta, node: Node) -> None:
|
||||
assert tm.node is not None
|
||||
tm.node.users[node] = None
|
||||
|
||||
def placeholder(self, a: Any):
|
||||
name = self._tensor_name()
|
||||
node = self.graph.placeholder(name)
|
||||
|
||||
if isinstance(a, TensorMeta):
|
||||
if a.node is not None:
|
||||
raise ValueError("Attempting to reuse a TensorMeta in a new trace!")
|
||||
a.tname = name
|
||||
a.node = node
|
||||
|
||||
return a
|
||||
|
||||
def output(self, tms: Sequence[TensorMeta]):
|
||||
# TODO: allow other output types
|
||||
flat_tms, _ = torch.utils._pytree.tree_flatten(tms)
|
||||
for tm in flat_tms:
|
||||
assert isinstance(tm, TensorMeta), f"Got non-TensorMeta output!, {type(tm)}"
|
||||
|
||||
node = self.graph.output(tms)
|
||||
for tm in flat_tms:
|
||||
self._add_user(tm, node)
|
||||
|
||||
def __torch_function__(
|
||||
self,
|
||||
func: Callable,
|
||||
types: Sequence,
|
||||
args: Sequence[Any] = (),
|
||||
kwargs: Dict = None,
|
||||
):
|
||||
"""
|
||||
Determines which function to call. The order of which
|
||||
function is called is determined by:
|
||||
|
||||
- func's "meta" attribute, if it exists
|
||||
- if func is a torch operation, its corresponding reference
|
||||
- func
|
||||
"""
|
||||
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
|
||||
if hasattr(func, "meta"):
|
||||
# TODO: add check that all args/kwargs are 'registered' properly
|
||||
# to this trace
|
||||
|
||||
output = func.meta(*args, **kwargs) # type: ignore[attr-defined]
|
||||
|
||||
# Updates graph
|
||||
# TODO: handle outputs with multiple tensors
|
||||
# TODO: handle non-tensor outputs
|
||||
assert isinstance(output, TensorMeta)
|
||||
output_name = self._tensor_name()
|
||||
node = self.graph.create_node(
|
||||
"call_function", func, name=output_name, args=args, kwargs=kwargs
|
||||
)
|
||||
output.tname = output_name
|
||||
output.node = node
|
||||
|
||||
# Marks uses
|
||||
for x in (
|
||||
x for x in chain(args, kwargs.values()) if isinstance(x, TensorMeta)
|
||||
):
|
||||
self._add_user(x, node)
|
||||
|
||||
return output
|
||||
|
||||
# Remaps torch operations to their references
|
||||
if func in _torch_to_reference_map:
|
||||
fn = _torch_to_reference_map[func]
|
||||
with torch.overrides.enable_torch_function_mode(self, replace=self.inner):
|
||||
return fn(*args, **kwargs) # type: ignore[operator]
|
||||
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def torch_to_refs_map():
|
||||
"""
|
||||
|
|
@ -183,7 +43,7 @@ def torch_to_refs_map():
|
|||
]
|
||||
r = {}
|
||||
for mod_torch, mod_refs in modules:
|
||||
for s in mod_refs.__all__:
|
||||
for s in mod_refs.__all__: # type: ignore[attr-defined]
|
||||
r[mod_torch.__dict__.get(s)] = mod_refs.__dict__.get(s)
|
||||
return r
|
||||
|
||||
|
|
|
|||
|
|
@ -3,15 +3,17 @@ from typing import Callable
|
|||
import torch
|
||||
|
||||
from torch.fx import GraphModule
|
||||
from torch._prims.utils import TensorMeta, getnvFuserDtype
|
||||
from torch._prims.context import PrimContext
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._prims.utils import getnvFuserDtype
|
||||
from torch._prims.context import TorchRefsMode
|
||||
import torch.overrides
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
if torch.cuda.is_available():
|
||||
from torch._C._nvfuser import Fusion, FusionDefinition # type: ignore[import]
|
||||
|
||||
|
||||
def execute(ctx: PrimContext, *args, executor: str = "aten", **kwargs):
|
||||
def execute(gm: GraphModule, *args, executor: str = "aten", **kwargs):
|
||||
"""
|
||||
Prototype ATen executor.
|
||||
|
||||
|
|
@ -19,7 +21,6 @@ def execute(ctx: PrimContext, *args, executor: str = "aten", **kwargs):
|
|||
"""
|
||||
|
||||
if executor == "aten":
|
||||
gm = GraphModule({}, ctx.graph)
|
||||
return gm.forward(*args, **kwargs)
|
||||
elif executor == "nvfuser":
|
||||
if not torch.cuda.is_available():
|
||||
|
|
@ -28,36 +29,32 @@ def execute(ctx: PrimContext, *args, executor: str = "aten", **kwargs):
|
|||
)
|
||||
|
||||
# PROTOTYPE nvfuser executor
|
||||
# Only accepts tensor inputs and single tensor outputs
|
||||
# Does not handle kwargs
|
||||
# Does not support reusing the same ctx to execute!
|
||||
assert len(kwargs) == 0
|
||||
# TODO: make this a proper trace -> trace transform that
|
||||
# doesn't mutate the context
|
||||
graph_fd = ctx.graph.placeholder("fd")
|
||||
ctx.graph._root.append(graph_fd)
|
||||
# Everything in the graph must support nvfuser
|
||||
|
||||
fusion = Fusion()
|
||||
with FusionDefinition(fusion) as fd:
|
||||
# Transforms graph to call nvfuser lowerings
|
||||
nv_args = [fd]
|
||||
for arg in args:
|
||||
|
||||
class FusionInterpreter(torch.fx.Interpreter):
|
||||
def call_function(self, target, args, kwargs):
|
||||
target = target.impl_nvfuser
|
||||
args = (fd,) + args
|
||||
return target(*args, **kwargs)
|
||||
|
||||
def to_nv(arg):
|
||||
if isinstance(arg, torch.Tensor):
|
||||
x = fd.define_tensor(
|
||||
arg.size(), arg.stride(), getnvFuserDtype(arg.dtype)
|
||||
)
|
||||
fd.add_input(x)
|
||||
nv_args.append(x)
|
||||
return x
|
||||
else:
|
||||
nv_args.append(x)
|
||||
return arg
|
||||
|
||||
for x in ctx.graph.nodes:
|
||||
if x.op == "call_function":
|
||||
x.target = x.target.impl_nvfuser
|
||||
x.args = (graph_fd,) + x.args
|
||||
# Transforms graph to call nvfuser lowerings
|
||||
nv_args = tree_map(to_nv, args)
|
||||
nv_kwargs = tree_map(to_nv, kwargs)
|
||||
|
||||
gm = GraphModule({}, ctx.graph)
|
||||
out = gm.forward(*nv_args)
|
||||
out = FusionInterpreter(gm).run(*nv_args, **nv_kwargs)
|
||||
flat_out, unflatten_spec = torch.utils._pytree.tree_flatten(out)
|
||||
for o in flat_out:
|
||||
fd.add_output(o)
|
||||
|
|
@ -102,17 +99,9 @@ def make_traced(fn: Callable):
|
|||
"""
|
||||
|
||||
def _traced(*args, executor="aten"):
|
||||
ctx: PrimContext
|
||||
with torch.overrides.push_torch_function_mode(PrimContext) as ctx: # type: ignore[attr-defined, assignment]
|
||||
placeholders = []
|
||||
for arg in args:
|
||||
if isinstance(arg, torch.Tensor):
|
||||
placeholders.append(ctx.placeholder(TensorMeta(arg)))
|
||||
else:
|
||||
placeholders.append(ctx.placeholder(arg))
|
||||
|
||||
result = fn(*placeholders)
|
||||
ctx.output(result)
|
||||
return execute(ctx, *args, executor=executor)
|
||||
# TODO: caching
|
||||
with TorchRefsMode.push():
|
||||
gm = make_fx(fn)(*args)
|
||||
return execute(gm, *args, executor=executor)
|
||||
|
||||
return _traced
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Union, Sequence, Optional, Callable, Dict, Tuple, List
|
||||
from typing import Any, Union, Sequence, Optional, Tuple, List
|
||||
from enum import Enum
|
||||
from functools import reduce, cmp_to_key
|
||||
import operator
|
||||
from torch._subclasses.fake_tensor import FakeTensor
|
||||
|
||||
import torch
|
||||
|
||||
|
|
@ -55,119 +56,57 @@ torch_function_passthrough = {
|
|||
}
|
||||
|
||||
|
||||
class TensorMeta(torch.Tensor):
|
||||
"""
|
||||
Model tensor metadata. Not a stock meta tensor because device is modeled
|
||||
as the original device (not meta device), also we have different behavior
|
||||
for some high level Python bindings
|
||||
"""
|
||||
|
||||
# Note: this will be an fx Node if it's ever
|
||||
# populated, but some Meta-internal jobs don't include fx
|
||||
node: Optional[Any]
|
||||
tname: str
|
||||
|
||||
@staticmethod
|
||||
def __new__(
|
||||
cls,
|
||||
tensorlike: Optional[Union[TensorMeta, NumberType, torch.Tensor]] = None,
|
||||
*,
|
||||
shape: Optional[ShapeType] = None,
|
||||
strides: Optional[StrideType] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[Union[torch.device, str]] = None,
|
||||
):
|
||||
|
||||
if isinstance(tensorlike, Number):
|
||||
assert not shape and (shape is None or isinstance(shape, Sequence))
|
||||
assert not strides and (strides is None or isinstance(strides, Sequence))
|
||||
inferred_shape: Tuple[int, ...] = ()
|
||||
inferred_strides: Tuple[int, ...] = ()
|
||||
inferred_dtype = type_to_dtype(type(tensorlike))
|
||||
inferred_device = torch.device("cpu")
|
||||
# TODO: This looks wrong, a number that is wrapped into a tensor
|
||||
# needs to behave differently than a scalar tensor for type
|
||||
# promotion purposes
|
||||
elif tensorlike is not None:
|
||||
assert isinstance(tensorlike, (TensorMeta, torch.Tensor))
|
||||
inferred_shape = tuple(tensorlike.shape)
|
||||
inferred_strides = tuple(tensorlike.stride())
|
||||
inferred_dtype = tensorlike.dtype
|
||||
inferred_device = tensorlike.device
|
||||
else:
|
||||
# If no tensorlike "example" is given then all metadata
|
||||
# must be provided explicitly
|
||||
assert shape is not None
|
||||
assert strides is not None
|
||||
assert dtype is not None
|
||||
assert device is not None
|
||||
|
||||
shape = inferred_shape if shape is None else tuple(shape)
|
||||
strides = inferred_strides if strides is None else tuple(strides)
|
||||
dtype = inferred_dtype if dtype is None else dtype
|
||||
device = inferred_device if device is None else device
|
||||
|
||||
if isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
|
||||
r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined]
|
||||
cls,
|
||||
shape,
|
||||
strides=strides,
|
||||
storage_offset=0, # TODO: this is inaccurate
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
r.tname = ""
|
||||
r.node = None
|
||||
return r
|
||||
|
||||
@classmethod
|
||||
def __torch_function__(
|
||||
cls,
|
||||
func: Callable,
|
||||
types: Sequence,
|
||||
args: Sequence[Any] = (),
|
||||
kwargs: Optional[Dict] = None,
|
||||
):
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
|
||||
if func in torch_function_passthrough:
|
||||
return super().__torch_function__(func, types, args, kwargs)
|
||||
|
||||
if not hasattr(func, "meta"):
|
||||
raise ValueError(f"Callable {func} has no meta function!")
|
||||
|
||||
return func.meta(*args, **kwargs) # type: ignore[attr-defined]
|
||||
|
||||
@classmethod
|
||||
def __torch_dispatch__(
|
||||
cls,
|
||||
func,
|
||||
types,
|
||||
args=(),
|
||||
kwargs=None,
|
||||
):
|
||||
raise RuntimeError("this should be unreachable")
|
||||
|
||||
# TODO: fx uses dunder repr to print objects in code
|
||||
def __repr__(self):
|
||||
return self.tname
|
||||
# return f"TensorMeta(dtype={self.dtype}, device={self.device}, shape={self.shape}, strides={self.stride()})"
|
||||
|
||||
def __format__(self, format_spec):
|
||||
return self.tname
|
||||
|
||||
|
||||
TensorLikeType = Union[torch.Tensor, TensorMeta]
|
||||
TensorLike = (torch.Tensor, TensorMeta)
|
||||
TensorLikeType = torch.Tensor
|
||||
TensorLike = torch.Tensor
|
||||
TensorSequenceType = Union[List[TensorLikeType], Tuple[TensorLikeType, ...]]
|
||||
TensorOrNumberLikeType = Union[TensorLikeType, NumberType]
|
||||
|
||||
|
||||
def TensorMeta(
|
||||
tensorlike: Optional[Union[NumberType, torch.Tensor]] = None,
|
||||
*,
|
||||
shape: Optional[ShapeType] = None,
|
||||
strides: Optional[StrideType] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[Union[torch.device, str]] = None,
|
||||
):
|
||||
if isinstance(tensorlike, Number):
|
||||
assert not shape and (shape is None or isinstance(shape, Sequence))
|
||||
assert not strides and (strides is None or isinstance(strides, Sequence))
|
||||
inferred_shape: Tuple[int, ...] = ()
|
||||
inferred_strides: Tuple[int, ...] = ()
|
||||
inferred_dtype = type_to_dtype(type(tensorlike))
|
||||
inferred_device = torch.device("cpu")
|
||||
# TODO: This looks wrong, a number that is wrapped into a tensor
|
||||
# needs to behave differently than a scalar tensor for type
|
||||
# promotion purposes
|
||||
elif tensorlike is not None:
|
||||
assert isinstance(tensorlike, torch.Tensor)
|
||||
inferred_shape = tuple(tensorlike.shape)
|
||||
inferred_strides = tuple(tensorlike.stride())
|
||||
inferred_dtype = tensorlike.dtype
|
||||
inferred_device = tensorlike.device
|
||||
else:
|
||||
# If no tensorlike "example" is given then all metadata
|
||||
# must be provided explicitly
|
||||
assert shape is not None
|
||||
assert strides is not None
|
||||
assert dtype is not None
|
||||
assert device is not None
|
||||
|
||||
shape = inferred_shape if shape is None else tuple(shape)
|
||||
strides = inferred_strides if strides is None else tuple(strides)
|
||||
dtype = inferred_dtype if dtype is None else dtype
|
||||
device = inferred_device if device is None else device
|
||||
|
||||
if isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
|
||||
return FakeTensor(
|
||||
torch.empty_strided(shape, strides, dtype=dtype, device="meta"), device
|
||||
)
|
||||
|
||||
|
||||
# TODO: look at using torch.testing.assert_close instead with an option
|
||||
# to just compare metadata
|
||||
def compare_tensor_meta(a: TensorLikeType, b: TensorLikeType):
|
||||
|
|
|
|||
|
|
@ -36,6 +36,7 @@ import operator
|
|||
import warnings
|
||||
import math
|
||||
from enum import Enum
|
||||
import collections
|
||||
|
||||
# Experimental module containing prototype Python references for existing
|
||||
# PyTorch operations.
|
||||
|
|
@ -1576,10 +1577,14 @@ def addr(
|
|||
|
||||
|
||||
def atleast_1d(
|
||||
*args: TensorLikeType,
|
||||
arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType
|
||||
) -> Union[TensorLikeType, Tuple[TensorLikeType, ...]]:
|
||||
"""Reference implementation of :func:`torch.atleast_1d`."""
|
||||
args_ = args[0] if len(args) == 1 and not torch.is_tensor(args[0]) else args
|
||||
if not args and isinstance(arg, collections.Sequence):
|
||||
args_ = arg
|
||||
else:
|
||||
assert not isinstance(arg, collections.Sequence)
|
||||
args_ = (arg,) + args
|
||||
res = tuple(a if a.ndim >= 1 else unsqueeze(a, 0) for a in args_)
|
||||
return res if len(res) > 1 else res[0]
|
||||
|
||||
|
|
@ -1595,20 +1600,28 @@ def _unsqueeze_atleast(
|
|||
|
||||
|
||||
def atleast_2d(
|
||||
*args: TensorLikeType,
|
||||
arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType
|
||||
) -> Union[TensorLikeType, Tuple[TensorLikeType, ...]]:
|
||||
"""Reference implementation of :func:`torch.atleast_2d`."""
|
||||
args_ = args[0] if len(args) == 1 and not torch.is_tensor(args[0]) else args
|
||||
if not args and isinstance(arg, collections.Sequence):
|
||||
args_ = arg
|
||||
else:
|
||||
assert not isinstance(arg, collections.Sequence)
|
||||
args_ = (arg,) + args
|
||||
unsqueeze_atleast_1d = partial(_unsqueeze_atleast, atleast_1d, 0)
|
||||
res = tuple(a if a.ndim >= 2 else unsqueeze_atleast_1d(a) for a in args_)
|
||||
return res if len(res) > 1 else res[0]
|
||||
|
||||
|
||||
def atleast_3d(
|
||||
*args: TensorLikeType,
|
||||
arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType
|
||||
) -> Union[TensorLikeType, Tuple[TensorLikeType, ...]]:
|
||||
"""Reference implementation of :func:`torch.atleast_3d`."""
|
||||
args_ = args[0] if len(args) == 1 and not torch.is_tensor(args[0]) else args
|
||||
if not args and isinstance(arg, collections.Sequence):
|
||||
args_ = arg
|
||||
else:
|
||||
assert not isinstance(arg, collections.Sequence)
|
||||
args_ = (arg,) + args
|
||||
unsqueeze_atleast_2d = partial(_unsqueeze_atleast, atleast_2d, -1)
|
||||
res = tuple(a if a.ndim >= 3 else unsqueeze_atleast_2d(a) for a in args_)
|
||||
return res if len(res) > 1 else res[0]
|
||||
|
|
|
|||
|
|
@ -180,7 +180,6 @@ class FakeTensor(torch.Tensor):
|
|||
# elem does not need to be recorded, because FakeTensor *is a* elem
|
||||
assert elem.device.type == "meta"
|
||||
device = device if isinstance(device, torch.device) else torch.device(device)
|
||||
assert device.type != "meta"
|
||||
self.fake_device = device
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import torch
|
||||
from torch._prims.utils import is_complex_dtype, corresponding_real_dtype
|
||||
from torch.utils._mode_utils import no_dispatch
|
||||
|
||||
def safe_is_leaf(t):
|
||||
|
|
@ -50,8 +49,8 @@ class MetaConverter:
|
|||
base = self.meta_tensor(t._base)
|
||||
|
||||
def is_c_of_r(complex_dtype, real_dtype):
|
||||
return is_complex_dtype(complex_dtype) and \
|
||||
corresponding_real_dtype(complex_dtype) == real_dtype
|
||||
return utils.is_complex_dtype(complex_dtype) and \
|
||||
utils.corresponding_real_dtype(complex_dtype) == real_dtype
|
||||
|
||||
if base.dtype == t.dtype:
|
||||
pass
|
||||
|
|
@ -138,3 +137,5 @@ class MetaConverter:
|
|||
else:
|
||||
# non-Tensor types don't count as hit or miss
|
||||
return t
|
||||
|
||||
import torch._prims.utils as utils
|
||||
|
|
|
|||
|
|
@ -226,6 +226,13 @@ void initDispatchBindings(PyObject* module) {
|
|||
return states;
|
||||
});
|
||||
|
||||
m.def("_dispatch_tls_set_dispatch_key_excluded", [](const char* dispatch_key, bool desired_state) {
|
||||
c10::impl::tls_set_dispatch_key_excluded(c10::parseDispatchKey(dispatch_key), desired_state);
|
||||
});
|
||||
m.def("_dispatch_tls_is_dispatch_key_excluded", [](const char* dispatch_key) {
|
||||
return c10::impl::tls_is_dispatch_key_excluded(c10::parseDispatchKey(dispatch_key));
|
||||
});
|
||||
|
||||
// Prints out the name of every operator that has a kernel registered to the Dispatcher
|
||||
// under [dispatch_key].
|
||||
// If no arguments are specified, it'll print out the name of every operator that the Dispatcher knows of.
|
||||
|
|
|
|||
|
|
@ -19416,6 +19416,12 @@ def _inherit_constructor_args(name, op, inherited, overrides):
|
|||
kwargs.update(common_kwargs)
|
||||
kwargs.update(overrides)
|
||||
|
||||
kwargs['supports_autograd'] = False
|
||||
kwargs['supports_gradgrad'] = False
|
||||
kwargs['supports_fwgrad_bwgrad'] = False
|
||||
kwargs['supports_inplace_autograd'] = False
|
||||
kwargs['supports_forward_ad'] = False
|
||||
|
||||
return kwargs
|
||||
|
||||
class PythonRefInfo(OpInfo):
|
||||
|
|
@ -19755,10 +19761,6 @@ python_ref_db = [
|
|||
PythonRefInfo(
|
||||
"_refs.nn.functional.leaky_relu",
|
||||
torch_opinfo_name="nn.functional.leaky_relu",
|
||||
decorators=(
|
||||
# Need FakeTensor support for meta coverage
|
||||
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta',),
|
||||
),
|
||||
),
|
||||
ElementwiseUnaryPythonRefInfo(
|
||||
"_refs.nn.functional.relu",
|
||||
|
|
@ -20078,11 +20080,6 @@ python_ref_db = [
|
|||
PythonRefInfo(
|
||||
"_refs.stack",
|
||||
torch_opinfo_name="stack",
|
||||
skips=(
|
||||
# https://github.com/pytorch/pytorch/issues/77046
|
||||
DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
|
||||
DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
|
||||
),
|
||||
),
|
||||
PythonRefInfo(
|
||||
"_refs.squeeze",
|
||||
|
|
@ -20105,10 +20102,6 @@ python_ref_db = [
|
|||
PythonRefInfo(
|
||||
"_refs.t",
|
||||
torch_opinfo_name="t",
|
||||
decorators=(
|
||||
# Need FakeTensor support for meta coverage
|
||||
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta',),
|
||||
),
|
||||
),
|
||||
PythonRefInfo(
|
||||
"_refs.unsqueeze",
|
||||
|
|
@ -20180,8 +20173,6 @@ python_ref_db = [
|
|||
"_refs.addr",
|
||||
torch_opinfo_name="addr",
|
||||
decorators=(
|
||||
# RuntimeError: no _refs support for torch.outer
|
||||
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta',),
|
||||
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref',),
|
||||
),
|
||||
),
|
||||
|
|
|
|||
Loading…
Reference in a new issue