Replace TensorMeta with FakeTensor

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78836

Approved by: https://github.com/albanD, https://github.com/mruberry
Edward Z. Yang 2022-06-04 19:53:26 -07:00 committed by PyTorch MergeBot
parent 484282a6fd
commit 587efdb5fa
11 changed files with 160 additions and 336 deletions
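
The change in one line: torch._prims stops modeling meta-only tensors with its own TensorMeta wrapper subclass and reuses the shared FakeTensor subclass, a meta tensor that additionally remembers the original ("fake") device. A minimal sketch of the behavior the diff below relies on (illustrative, not part of the commit):

    import torch
    from torch._subclasses.fake_tensor import FakeTensor

    t = torch.randn(2, 3, device="cpu")
    f = FakeTensor.from_tensor(t)   # the constructor used throughout this diff
    assert f.shape == t.shape and f.dtype == t.dtype
    assert f.device == t.device     # reports the remembered fake_device, not "meta"
    # f has no real storage; operations on it propagate shapes and dtypes only.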

test/test_ops.py

@@ -14,6 +14,7 @@ from torch.testing._internal.common_dtype import (
floating_and_complex_types_and,
all_types_and_complex_and,
)
from torch._subclasses.fake_tensor import FakeTensor
from torch.testing._internal.common_utils import (
TestCase,
is_iterable_of_tensors,
@@ -356,7 +357,7 @@ class TestCommon(TestCase):
def _to_tensormeta(x):
if isinstance(x, torch.Tensor):
return prims.utils.TensorMeta(x)
return FakeTensor.from_tensor(x)
return x
# TODO: iterate over requires_grad true/false
@@ -506,7 +507,7 @@ class TestCommon(TestCase):
def test_python_ref_errors(self, device, op):
def _to_tensormeta(x):
if isinstance(x, torch.Tensor):
return prims.utils.TensorMeta(x)
return FakeTensor.from_tensor(x)
return x
error_inputs = op.error_inputs(device)
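
For context, this is how the converted inputs flow through the tests: every tensor in a (possibly nested) sample is swapped for a FakeTensor before the ref is called, so the test exercises pure metadata propagation. A self-contained sketch (helper name is illustrative, not the exact test code):

    import torch
    from torch._subclasses.fake_tensor import FakeTensor
    from torch.utils._pytree import tree_map

    def _to_fake(x):
        if isinstance(x, torch.Tensor):
            return FakeTensor.from_tensor(x)
        return x

    sample = (torch.randn(3), {"bias": torch.ones(2)}, 4.0)
    fake_sample = tree_map(_to_fake, sample)  # tensors become fake; 4.0 passes through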

torch/_C/__init__.pyi.in

@@ -840,6 +840,8 @@ class Generator(object):
def _dispatch_library(kind: str, name: str, dispatch: str, file: str = "", linenum: Any = 0) -> Any: ...
def _dispatch_has_kernel_for_dispatch_key(name: str, dispatch: str) -> _bool: ...
def _dispatch_has_kernel(name: str) -> _bool: ...
def _dispatch_tls_is_dispatch_key_excluded(dispatch: str) -> _bool: ...
def _dispatch_tls_set_dispatch_key_excluded(dispatch: str, val: _bool) -> None: ...
# Defined in torch/csrc/utils/init.cpp
class BenchmarkConfig(object):
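
Usage sketch for the two new bindings, mirroring the _DispatchBelowAutograd context manager added to torch/_prims/__init__.py below:

    import torch

    key = "AutogradFunctionality"
    old = torch._C._dispatch_tls_is_dispatch_key_excluded(key)
    torch._C._dispatch_tls_set_dispatch_key_excluded(key, True)
    try:
        ...  # calls made here skip autograd kernels for this thread
    finally:
        torch._C._dispatch_tls_set_dispatch_key_excluded(key, old)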

torch/_prims/__init__.py

@@ -5,7 +5,6 @@ import torch._prims.utils as utils
from torch._prims.utils import (
TensorLike,
TensorLikeType,
TensorMeta,
ShapeType,
getnvFuserDtype,
DimsType,
@@ -13,11 +12,14 @@ from torch._prims.utils import (
StrideType,
Number,
NumberType,
TensorMeta,
)
from torch.overrides import has_torch_function, handle_torch_function
import torch.library
from torch.utils._pytree import tree_map
from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
from torch._subclasses.fake_tensor import FakeTensor
import contextlib
from typing import Sequence, Optional, Union, Callable, List, Tuple, Any, Type
from functools import reduce, partial
from enum import Enum
@@ -26,6 +28,7 @@ import math
prim = torch.library.Library("prims", "DEF")
prim_impl = torch.library.Library("prims", "IMPL", "CompositeExplicitAutograd")
prim_autograd_impl = torch.library.Library("prims", "IMPL", "Autograd")
prim_meta_impl = torch.library.Library("prims", "IMPL", "Meta")
# Experimental module containing prototype "primitive" operations.
@@ -286,27 +289,29 @@ class RETURN_TYPE(Enum):
def _wrap_tensor_meta(f):
def wrap(t):
if isinstance(t, torch.Tensor):
return TensorMeta(t)
else:
return t
def unwrap(t):
# TODO: doesn't setup aliasing relation on views correctly
if isinstance(t, TensorMeta):
return torch.empty_strided(
t.shape, t.stride(), dtype=t.dtype, device="meta"
)
return FakeTensor.from_tensor(t)
else:
return t
def wrapper(*args, **kwargs):
wrapped_args = tree_map(wrap, args)
wrapped_kwargs = tree_map(wrap, kwargs)
return tree_map(unwrap, f(*wrapped_args, **wrapped_kwargs))
return f(*wrapped_args, **wrapped_kwargs)
return wrapper
@contextlib.contextmanager
def _DispatchBelowAutograd():
# TODO: AutogradOther
old = torch._C._dispatch_tls_is_dispatch_key_excluded("AutogradFunctionality")
torch._C._dispatch_tls_set_dispatch_key_excluded("AutogradFunctionality", True)
try:
yield
finally:
torch._C._dispatch_tls_set_dispatch_key_excluded("AutogradFunctionality", old)
def _make_prim(
*,
schema: str,
@@ -330,16 +335,33 @@
meta(*args, **kwargs)
return impl_aten(*args, **kwargs)
class BackwardsNotSupported(torch.autograd.Function):
@staticmethod
def forward(ctx, args_spec, *flat_args):
args, kwargs = tree_unflatten(flat_args, args_spec) # type: ignore[arg-type]
with _DispatchBelowAutograd():
return _prim(*args, **kwargs)
@staticmethod
def backward(ctx, *args):
raise RuntimeError("backwards not supported on prim")
def _autograd_impl(*args, **kwargs):
flat_args, args_spec = tree_flatten((args, kwargs))
return BackwardsNotSupported.apply(args_spec, *flat_args)
name = schema.split("(")[0]
prim_impl.impl(name, _prim_impl)
prim_autograd_impl.impl(name, _autograd_impl)
prim_meta_impl.impl(name, _wrap_tensor_meta(meta))
_prim = getattr(torch.ops.prims, name).default
_prim_packet = getattr(torch.ops.prims, name)
_prim = _prim_packet.default
_prim.__doc__ = doc
_prim.meta = meta # type: ignore[attr-defined]
_prim.impl_nvfuser = impl_nvfuser # type: ignore[attr-defined]
_prim.return_type = return_type # type: ignore[attr-defined]
for p in (_prim_packet, _prim):
p.__doc__ = doc
p.impl_nvfuser = impl_nvfuser # type: ignore[attr-defined]
p.return_type = return_type # type: ignore[attr-defined]
return _prim
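
Why the flatten/unflatten dance above: autograd.Function.apply accepts only flat positional arguments and only tracks tensors it can see at the top level, so (args, kwargs) are flattened into a flat list plus a pytree spec and rebuilt inside forward. The pattern in isolation (a sketch, not the registration code itself):

    from torch.utils._pytree import tree_flatten, tree_unflatten

    def call_via_trampoline(fn, *args, **kwargs):
        flat_args, spec = tree_flatten((args, kwargs))
        # apply() would receive (spec, *flat_args); forward() rebuilds the
        # original structure before invoking the prim:
        new_args, new_kwargs = tree_unflatten(flat_args, spec)
        return fn(*new_args, **new_kwargs)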
@@ -355,7 +377,7 @@ def _elementwise_meta(
*args,
type_promotion: ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND,
args_with_fixed_dtypes: Tuple[TensorLikeType, ...] = None,
) -> TensorMeta:
) -> FakeTensor:
"""
Meta function for elementwise operations that produce outputs in the same dtype
as their inputs.
@@ -1926,7 +1948,7 @@ device_put = _make_prim(
# NOTE: need to model meta scalars
# See https://github.com/pytorch/pytorch/issues/78070
def _item_meta(a: TensorLikeType) -> TensorMeta:
def _item_meta(a: TensorLikeType) -> FakeTensor:
number_type = utils.dtype_to_type(a.dtype)
return TensorMeta(number_type(-1))
@@ -1948,7 +1970,7 @@ item = _make_prim(
# NOTE: need to model meta scalars
# See https://github.com/pytorch/pytorch/issues/78070
def _maximum_value_meta(dtype: torch.dtype) -> TensorMeta:
def _maximum_value_meta(dtype: torch.dtype) -> FakeTensor:
number_type = utils.dtype_to_type(dtype)
return TensorMeta(number_type(-1))
@@ -1980,7 +2002,7 @@ maximum_value = _make_prim(
# NOTE: need to model meta scalars
# See https://github.com/pytorch/pytorch/issues/78070
def _minimum_value_meta(dtype: torch.dtype) -> TensorMeta:
def _minimum_value_meta(dtype: torch.dtype) -> FakeTensor:
number_type = utils.dtype_to_type(dtype)
return TensorMeta(number_type(-1))
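(Each of these scalar metas fabricates a 0-dim FakeTensor through the TensorMeta factory kept in torch/_prims/utils.py, since meta tensors cannot yet carry a concrete scalar value; see the linked issue.)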

torch/_prims/context.py

@@ -1,14 +1,11 @@
import string
from typing import Callable, Sequence, Any, Dict
from itertools import chain
import functools
import torch
from torch.fx.graph import Graph, Node
import torch.overrides
from torch._prims.utils import TensorMeta, torch_function_passthrough
from torch._prims.utils import torch_function_passthrough
import torch._refs as refs
import torch._refs
@@ -32,143 +29,6 @@ _torch_to_reference_map = {
}
class PrimContext(torch.overrides.TorchFunctionMode):
"""
The prototype prim tracing context.
Example usage:
import torch._prims.utils as utils
from torch._prims.context import PrimContext
from torch._prims.executor import execute
from torch.overrides import push_torch_function_mode
a = torch.randn((2, 2))
b = torch.randn((2, 2))
with push_torch_function_mode(PrimContext):
meta_a = ctx.placeholder(utils.TensorMeta(a))
meta_b = ctx.placeholder(utils.TensorMeta(b))
result = torch.add(meta_a, meta_b)
ctx.output(result)
exc_result = execute(ctx, a, b)
Currently this only acquires a trace of prims, and
it does not account for control flow. As such,
execute must be called with tensors that have the
same metadata (dtype, device, shape...) as
the tensors used to trace the operations.
The tracing context's FX graph can be acquired
using its graph attribute.
"""
def __init__(self):
self.graph = Graph()
# Private attributes for generating names
self._tensor_name_counter = 0
self._dim_name_counter = 0
self._shape_name_counter = 0
self._lowercase = tuple(string.ascii_lowercase)
self._uppercase = tuple(string.ascii_uppercase)
@staticmethod
def _create_name(idx, chars):
name = ""
while idx >= len(chars):
name = chars[idx % len(chars)] + name
idx = idx - len(chars)
name = chars[idx] + name
return name
def _tensor_name(self):
idx = self._tensor_name_counter
self._tensor_name_counter = self._tensor_name_counter + 1
return self._create_name(idx, self._lowercase)
def _add_user(self, tm: TensorMeta, node: Node) -> None:
assert tm.node is not None
tm.node.users[node] = None
def placeholder(self, a: Any):
name = self._tensor_name()
node = self.graph.placeholder(name)
if isinstance(a, TensorMeta):
if a.node is not None:
raise ValueError("Attempting to reuse a TensorMeta in a new trace!")
a.tname = name
a.node = node
return a
def output(self, tms: Sequence[TensorMeta]):
# TODO: allow other output types
flat_tms, _ = torch.utils._pytree.tree_flatten(tms)
for tm in flat_tms:
assert isinstance(tm, TensorMeta), f"Got non-TensorMeta output!, {type(tm)}"
node = self.graph.output(tms)
for tm in flat_tms:
self._add_user(tm, node)
def __torch_function__(
self,
func: Callable,
types: Sequence,
args: Sequence[Any] = (),
kwargs: Dict = None,
):
"""
Determines which function to call. The order of which
function is called is determined by:
- func's "meta" attribute, if it exists
- if func is a torch operation, its corresponding reference
- func
"""
if kwargs is None:
kwargs = {}
if hasattr(func, "meta"):
# TODO: add check that all args/kwargs are 'registered' properly
# to this trace
output = func.meta(*args, **kwargs) # type: ignore[attr-defined]
# Updates graph
# TODO: handle outputs with multiple tensors
# TODO: handle non-tensor outputs
assert isinstance(output, TensorMeta)
output_name = self._tensor_name()
node = self.graph.create_node(
"call_function", func, name=output_name, args=args, kwargs=kwargs
)
output.tname = output_name
output.node = node
# Marks uses
for x in (
x for x in chain(args, kwargs.values()) if isinstance(x, TensorMeta)
):
self._add_user(x, node)
return output
# Remaps torch operations to their references
if func in _torch_to_reference_map:
fn = _torch_to_reference_map[func]
with torch.overrides.enable_torch_function_mode(self, replace=self.inner):
return fn(*args, **kwargs) # type: ignore[operator]
return func(*args, **kwargs)
@functools.lru_cache(None)
def torch_to_refs_map():
"""
@@ -183,7 +43,7 @@ def torch_to_refs_map():
]
r = {}
for mod_torch, mod_refs in modules:
for s in mod_refs.__all__:
for s in mod_refs.__all__: # type: ignore[attr-defined]
r[mod_torch.__dict__.get(s)] = mod_refs.__dict__.get(s)
return r
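
Usage sketch for the surviving helper:

    import torch
    from torch._prims.context import torch_to_refs_map

    mapping = torch_to_refs_map()
    ref_add = mapping.get(torch.add)  # the _refs implementation, when one is registered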

torch/_prims/executor.py

@@ -3,15 +3,17 @@ from typing import Callable
import torch
from torch.fx import GraphModule
from torch._prims.utils import TensorMeta, getnvFuserDtype
from torch._prims.context import PrimContext
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.utils import getnvFuserDtype
from torch._prims.context import TorchRefsMode
import torch.overrides
from torch.utils._pytree import tree_map
if torch.cuda.is_available():
from torch._C._nvfuser import Fusion, FusionDefinition # type: ignore[import]
def execute(ctx: PrimContext, *args, executor: str = "aten", **kwargs):
def execute(gm: GraphModule, *args, executor: str = "aten", **kwargs):
"""
Prototype ATen executor.
@@ -19,7 +21,6 @@ def execute(ctx: PrimContext, *args, executor: str = "aten", **kwargs):
"""
if executor == "aten":
gm = GraphModule({}, ctx.graph)
return gm.forward(*args, **kwargs)
elif executor == "nvfuser":
if not torch.cuda.is_available():
@@ -28,36 +29,32 @@
)
# PROTOTYPE nvfuser executor
# Only accepts tensor inputs and single tensor outputs
# Does not handle kwargs
# Does not support reusing the same ctx to execute!
assert len(kwargs) == 0
# TODO: make this a proper trace -> trace transform that
# doesn't mutate the context
graph_fd = ctx.graph.placeholder("fd")
ctx.graph._root.append(graph_fd)
# Everything in the graph must support nvfuser
fusion = Fusion()
with FusionDefinition(fusion) as fd:
# Transforms graph to call nvfuser lowerings
nv_args = [fd]
for arg in args:
class FusionInterpreter(torch.fx.Interpreter):
def call_function(self, target, args, kwargs):
target = target.impl_nvfuser
args = (fd,) + args
return target(*args, **kwargs)
def to_nv(arg):
if isinstance(arg, torch.Tensor):
x = fd.define_tensor(
arg.size(), arg.stride(), getnvFuserDtype(arg.dtype)
)
fd.add_input(x)
nv_args.append(x)
return x
else:
nv_args.append(x)
return arg
for x in ctx.graph.nodes:
if x.op == "call_function":
x.target = x.target.impl_nvfuser
x.args = (graph_fd,) + x.args
# Transforms graph to call nvfuser lowerings
nv_args = tree_map(to_nv, args)
nv_kwargs = tree_map(to_nv, kwargs)
gm = GraphModule({}, ctx.graph)
out = gm.forward(*nv_args)
out = FusionInterpreter(gm).run(*nv_args, **nv_kwargs)
flat_out, unflatten_spec = torch.utils._pytree.tree_flatten(out)
for o in flat_out:
fd.add_output(o)
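
The FusionInterpreter replaces the old in-place graph surgery: rather than mutating node targets, an fx.Interpreter subclass reroutes each call_function node to the prim's impl_nvfuser lowering at execution time, leaving the traced GraphModule intact. The generic shape of that pattern, as a standalone sketch (not the nvfuser code):

    import torch
    import torch.fx

    class LoggingInterpreter(torch.fx.Interpreter):
        def call_function(self, target, args, kwargs):
            # the hook where a transform can swap or wrap `target`
            print("calling", target)
            return super().call_function(target, args, kwargs)

    gm = torch.fx.symbolic_trace(lambda x: x + 1)
    out = LoggingInterpreter(gm).run(torch.randn(3))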
@@ -102,17 +99,9 @@ def make_traced(fn: Callable):
"""
def _traced(*args, executor="aten"):
ctx: PrimContext
with torch.overrides.push_torch_function_mode(PrimContext) as ctx: # type: ignore[attr-defined, assignment]
placeholders = []
for arg in args:
if isinstance(arg, torch.Tensor):
placeholders.append(ctx.placeholder(TensorMeta(arg)))
else:
placeholders.append(ctx.placeholder(arg))
result = fn(*placeholders)
ctx.output(result)
return execute(ctx, *args, executor=executor)
# TODO: caching
with TorchRefsMode.push():
gm = make_fx(fn)(*args)
return execute(gm, *args, executor=executor)
return _traced
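
End-to-end usage of the rewritten entry point (sketch):

    import torch
    import torch._refs as refs
    from torch._prims.executor import make_traced

    def foo(a, b):
        return refs.add(a, b)

    traced = make_traced(foo)
    out = traced(torch.randn(2, 2), torch.randn(2, 2), executor="aten")
    # make_fx traces foo to a GraphModule, which execute() then replays.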

torch/_prims/utils.py

@@ -1,9 +1,10 @@
from __future__ import annotations
from typing import Any, Union, Sequence, Optional, Callable, Dict, Tuple, List
from typing import Any, Union, Sequence, Optional, Tuple, List
from enum import Enum
from functools import reduce, cmp_to_key
import operator
from torch._subclasses.fake_tensor import FakeTensor
import torch
@@ -55,119 +56,57 @@ torch_function_passthrough = {
}
class TensorMeta(torch.Tensor):
"""
Model tensor metadata. Not a stock meta tensor because device is modeled
as the original device (not meta device), also we have different behavior
for some high level Python bindings
"""
# Note: this will be an fx Node if it's ever
# populated, but some Meta-internal jobs don't include fx
node: Optional[Any]
tname: str
@staticmethod
def __new__(
cls,
tensorlike: Optional[Union[TensorMeta, NumberType, torch.Tensor]] = None,
*,
shape: Optional[ShapeType] = None,
strides: Optional[StrideType] = None,
dtype: Optional[torch.dtype] = None,
device: Optional[Union[torch.device, str]] = None,
):
if isinstance(tensorlike, Number):
assert not shape and (shape is None or isinstance(shape, Sequence))
assert not strides and (strides is None or isinstance(strides, Sequence))
inferred_shape: Tuple[int, ...] = ()
inferred_strides: Tuple[int, ...] = ()
inferred_dtype = type_to_dtype(type(tensorlike))
inferred_device = torch.device("cpu")
# TODO: This looks wrong, a number that is wrapped into a tensor
# needs to behave differently than a scalar tensor for type
# promotion purposes
elif tensorlike is not None:
assert isinstance(tensorlike, (TensorMeta, torch.Tensor))
inferred_shape = tuple(tensorlike.shape)
inferred_strides = tuple(tensorlike.stride())
inferred_dtype = tensorlike.dtype
inferred_device = tensorlike.device
else:
# If no tensorlike "example" is given then all metadata
# must be provided explicitly
assert shape is not None
assert strides is not None
assert dtype is not None
assert device is not None
shape = inferred_shape if shape is None else tuple(shape)
strides = inferred_strides if strides is None else tuple(strides)
dtype = inferred_dtype if dtype is None else dtype
device = inferred_device if device is None else device
if isinstance(device, str):
device = torch.device(device)
r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined]
cls,
shape,
strides=strides,
storage_offset=0, # TODO: this is inaccurate
dtype=dtype,
device=device,
requires_grad=False,
)
r.tname = ""
r.node = None
return r
@classmethod
def __torch_function__(
cls,
func: Callable,
types: Sequence,
args: Sequence[Any] = (),
kwargs: Optional[Dict] = None,
):
if kwargs is None:
kwargs = {}
if func in torch_function_passthrough:
return super().__torch_function__(func, types, args, kwargs)
if not hasattr(func, "meta"):
raise ValueError(f"Callable {func} has no meta function!")
return func.meta(*args, **kwargs) # type: ignore[attr-defined]
@classmethod
def __torch_dispatch__(
cls,
func,
types,
args=(),
kwargs=None,
):
raise RuntimeError("this should be unreachable")
# TODO: fx uses dunder repr to print objects in code
def __repr__(self):
return self.tname
# return f"TensorMeta(dtype={self.dtype}, device={self.device}, shape={self.shape}, strides={self.stride()})"
def __format__(self, format_spec):
return self.tname
TensorLikeType = Union[torch.Tensor, TensorMeta]
TensorLike = (torch.Tensor, TensorMeta)
TensorLikeType = torch.Tensor
TensorLike = torch.Tensor
TensorSequenceType = Union[List[TensorLikeType], Tuple[TensorLikeType, ...]]
TensorOrNumberLikeType = Union[TensorLikeType, NumberType]
def TensorMeta(
tensorlike: Optional[Union[NumberType, torch.Tensor]] = None,
*,
shape: Optional[ShapeType] = None,
strides: Optional[StrideType] = None,
dtype: Optional[torch.dtype] = None,
device: Optional[Union[torch.device, str]] = None,
):
if isinstance(tensorlike, Number):
assert not shape and (shape is None or isinstance(shape, Sequence))
assert not strides and (strides is None or isinstance(strides, Sequence))
inferred_shape: Tuple[int, ...] = ()
inferred_strides: Tuple[int, ...] = ()
inferred_dtype = type_to_dtype(type(tensorlike))
inferred_device = torch.device("cpu")
# TODO: This looks wrong, a number that is wrapped into a tensor
# needs to behave differently than a scalar tensor for type
# promotion purposes
elif tensorlike is not None:
assert isinstance(tensorlike, torch.Tensor)
inferred_shape = tuple(tensorlike.shape)
inferred_strides = tuple(tensorlike.stride())
inferred_dtype = tensorlike.dtype
inferred_device = tensorlike.device
else:
# If no tensorlike "example" is given then all metadata
# must be provided explicitly
assert shape is not None
assert strides is not None
assert dtype is not None
assert device is not None
shape = inferred_shape if shape is None else tuple(shape)
strides = inferred_strides if strides is None else tuple(strides)
dtype = inferred_dtype if dtype is None else dtype
device = inferred_device if device is None else device
if isinstance(device, str):
device = torch.device(device)
return FakeTensor(
torch.empty_strided(shape, strides, dtype=dtype, device="meta"), device
)
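
Effect of the rewritten factory, sketched (run in this module's namespace, which already has torch imported): existing call sites keep working, but what comes back is a FakeTensor whose storage is on the meta device while its reported device is the inferred one:

    t = TensorMeta(torch.randn(2, 3))  # from an example tensor
    s = TensorMeta(1.0)                # a wrapped Python scalar
    assert isinstance(t, FakeTensor) and t.device.type == "cpu"
    assert s.shape == () and s.dtype == torch.float32  # the default dtype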
# TODO: look at using torch.testing.assert_close instead with an option
# to just compare metadata
def compare_tensor_meta(a: TensorLikeType, b: TensorLikeType):

torch/_refs/__init__.py

@@ -36,6 +36,7 @@ import operator
import warnings
import math
from enum import Enum
import collections
# Experimental module containing prototype Python references for existing
# PyTorch operations.
@@ -1576,10 +1577,14 @@ def addr(
def atleast_1d(
*args: TensorLikeType,
arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType
) -> Union[TensorLikeType, Tuple[TensorLikeType, ...]]:
"""Reference implementation of :func:`torch.atleast_1d`."""
args_ = args[0] if len(args) == 1 and not torch.is_tensor(args[0]) else args
if not args and isinstance(arg, collections.abc.Sequence):
args_ = arg
else:
assert not isinstance(arg, collections.abc.Sequence)
args_ = (arg,) + args
res = tuple(a if a.ndim >= 1 else unsqueeze(a, 0) for a in args_)
return res if len(res) > 1 else res[0]
@@ -1595,20 +1600,28 @@ def _unsqueeze_atleast(
def atleast_2d(
*args: TensorLikeType,
arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType
) -> Union[TensorLikeType, Tuple[TensorLikeType, ...]]:
"""Reference implementation of :func:`torch.atleast_2d`."""
args_ = args[0] if len(args) == 1 and not torch.is_tensor(args[0]) else args
if not args and isinstance(arg, collections.abc.Sequence):
args_ = arg
else:
assert not isinstance(arg, collections.abc.Sequence)
args_ = (arg,) + args
unsqueeze_atleast_1d = partial(_unsqueeze_atleast, atleast_1d, 0)
res = tuple(a if a.ndim >= 2 else unsqueeze_atleast_1d(a) for a in args_)
return res if len(res) > 1 else res[0]
def atleast_3d(
*args: TensorLikeType,
arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType
) -> Union[TensorLikeType, Tuple[TensorLikeType, ...]]:
"""Reference implementation of :func:`torch.atleast_3d`."""
args_ = args[0] if len(args) == 1 and not torch.is_tensor(args[0]) else args
if not args and isinstance(arg, collections.abc.Sequence):
args_ = arg
else:
assert not isinstance(arg, collections.abc.Sequence)
args_ = (arg,) + args
unsqueeze_atleast_2d = partial(_unsqueeze_atleast, atleast_2d, -1)
res = tuple(a if a.ndim >= 3 else unsqueeze_atleast_2d(a) for a in args_)
return res if len(res) > 1 else res[0]
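
The new calling convention, common to all three atleast_* refs (sketch):

    import torch
    import torch._refs as refs

    a = torch.randn(())      # 0-dim
    b = torch.randn(3)
    refs.atleast_1d(a)       # single tensor -> tensor of shape (1,)
    refs.atleast_1d([a, b])  # a single sequence is accepted
    refs.atleast_1d(a, b)    # as are varargs; both return a tuple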

torch/_subclasses/fake_tensor.py

@@ -180,7 +180,6 @@ class FakeTensor(torch.Tensor):
# elem does not need to be recorded, because FakeTensor *is a* elem
assert elem.device.type == "meta"
device = device if isinstance(device, torch.device) else torch.device(device)
assert device.type != "meta"
self.fake_device = device
@staticmethod

torch/_subclasses/meta_utils.py

@@ -1,5 +1,4 @@
import torch
from torch._prims.utils import is_complex_dtype, corresponding_real_dtype
from torch.utils._mode_utils import no_dispatch
def safe_is_leaf(t):
@@ -50,8 +49,8 @@ class MetaConverter:
base = self.meta_tensor(t._base)
def is_c_of_r(complex_dtype, real_dtype):
return is_complex_dtype(complex_dtype) and \
corresponding_real_dtype(complex_dtype) == real_dtype
return utils.is_complex_dtype(complex_dtype) and \
utils.corresponding_real_dtype(complex_dtype) == real_dtype
if base.dtype == t.dtype:
pass
@@ -138,3 +137,5 @@ class MetaConverter:
else:
# non-Tensor types don't count as hit or miss
return t
import torch._prims.utils as utils
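(The utils import lands at the bottom of the file, presumably to sidestep a circular import between torch._subclasses and torch._prims; the helpers are now reached through the module instead of being imported by name.)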

torch/csrc/utils/python_dispatch.cpp

@@ -226,6 +226,13 @@ void initDispatchBindings(PyObject* module) {
return states;
});
m.def("_dispatch_tls_set_dispatch_key_excluded", [](const char* dispatch_key, bool desired_state) {
c10::impl::tls_set_dispatch_key_excluded(c10::parseDispatchKey(dispatch_key), desired_state);
});
m.def("_dispatch_tls_is_dispatch_key_excluded", [](const char* dispatch_key) {
return c10::impl::tls_is_dispatch_key_excluded(c10::parseDispatchKey(dispatch_key));
});
// Prints out the name of every operator that has a kernel registered to the Dispatcher
// under [dispatch_key].
// If no arguments are specified, it'll print out the name of every operator that the Dispatcher knows of.

torch/testing/_internal/common_methods_invocations.py

@@ -19416,6 +19416,12 @@ def _inherit_constructor_args(name, op, inherited, overrides):
kwargs.update(common_kwargs)
kwargs.update(overrides)
kwargs['supports_autograd'] = False
kwargs['supports_gradgrad'] = False
kwargs['supports_fwgrad_bwgrad'] = False
kwargs['supports_inplace_autograd'] = False
kwargs['supports_forward_ad'] = False
return kwargs
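(These flags are forced off because prims now register the BackwardsNotSupported autograd kernel added earlier in this commit, which raises on any backward call.)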
class PythonRefInfo(OpInfo):
@@ -19755,10 +19761,6 @@ python_ref_db = [
PythonRefInfo(
"_refs.nn.functional.leaky_relu",
torch_opinfo_name="nn.functional.leaky_relu",
decorators=(
# Need FakeTensor support for meta coverage
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta',),
),
),
ElementwiseUnaryPythonRefInfo(
"_refs.nn.functional.relu",
@@ -20078,11 +20080,6 @@ python_ref_db = [
PythonRefInfo(
"_refs.stack",
torch_opinfo_name="stack",
skips=(
# https://github.com/pytorch/pytorch/issues/77046
DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
),
),
PythonRefInfo(
"_refs.squeeze",
@@ -20105,10 +20102,6 @@ python_ref_db = [
PythonRefInfo(
"_refs.t",
torch_opinfo_name="t",
decorators=(
# Need FakeTensor support for meta coverage
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta',),
),
),
PythonRefInfo(
"_refs.unsqueeze",
@@ -20180,8 +20173,6 @@ python_ref_db = [
"_refs.addr",
torch_opinfo_name="addr",
decorators=(
# RuntimeError: no _refs support for torch.outer
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta',),
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref',),
),
),