diff --git a/test/test_ops.py b/test/test_ops.py index 26e2f436c1d..8740609ef38 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -14,6 +14,7 @@ from torch.testing._internal.common_dtype import ( floating_and_complex_types_and, all_types_and_complex_and, ) +from torch._subclasses.fake_tensor import FakeTensor from torch.testing._internal.common_utils import ( TestCase, is_iterable_of_tensors, @@ -356,7 +357,7 @@ class TestCommon(TestCase): def _to_tensormeta(x): if isinstance(x, torch.Tensor): - return prims.utils.TensorMeta(x) + return FakeTensor.from_tensor(x) return x # TODO: iterate over requires_grad true/false @@ -506,7 +507,7 @@ class TestCommon(TestCase): def test_python_ref_errors(self, device, op): def _to_tensormeta(x): if isinstance(x, torch.Tensor): - return prims.utils.TensorMeta(x) + return FakeTensor.from_tensor(x) return x error_inputs = op.error_inputs(device) diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index e8bed24bd16..24034bf8ec8 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -840,6 +840,8 @@ class Generator(object): def _dispatch_library(kind: str, name: str, dispatch: str, file: str = "", linenum: Any = 0) -> Any: ... def _dispatch_has_kernel_for_dispatch_key(name: str, dispatch: str) -> _bool: ... def _dispatch_has_kernel(name: str) -> _bool: ... +def _dispatch_tls_is_dispatch_key_excluded(dispatch: str) -> _bool: ... +def _dispatch_tls_set_dispatch_key_excluded(dispatch: str, val: _bool) -> None: ... 
# Defined in torch/csrc/utils/init.cpp class BenchmarkConfig(object): diff --git a/torch/_prims/__init__.py b/torch/_prims/__init__.py index dd219d12abe..6510ab3cd1d 100644 --- a/torch/_prims/__init__.py +++ b/torch/_prims/__init__.py @@ -5,7 +5,6 @@ import torch._prims.utils as utils from torch._prims.utils import ( TensorLike, TensorLikeType, - TensorMeta, ShapeType, getnvFuserDtype, DimsType, @@ -13,11 +12,14 @@ from torch._prims.utils import ( StrideType, Number, NumberType, + TensorMeta, ) from torch.overrides import has_torch_function, handle_torch_function import torch.library -from torch.utils._pytree import tree_map +from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten +from torch._subclasses.fake_tensor import FakeTensor +import contextlib from typing import Sequence, Optional, Union, Callable, List, Tuple, Any, Type from functools import reduce, partial from enum import Enum @@ -26,6 +28,7 @@ import math prim = torch.library.Library("prims", "DEF") prim_impl = torch.library.Library("prims", "IMPL", "CompositeExplicitAutograd") +prim_autograd_impl = torch.library.Library("prims", "IMPL", "Autograd") prim_meta_impl = torch.library.Library("prims", "IMPL", "Meta") # Experimental module containing prototype "primitive" operations. 
@@ -286,27 +289,29 @@ class RETURN_TYPE(Enum): def _wrap_tensor_meta(f): def wrap(t): if isinstance(t, torch.Tensor): - return TensorMeta(t) - else: - return t - - def unwrap(t): - # TODO: doesn't setup aliasing relation on views correctly - if isinstance(t, TensorMeta): - return torch.empty_strided( - t.shape, t.stride(), dtype=t.dtype, device="meta" - ) + return FakeTensor.from_tensor(t) else: return t def wrapper(*args, **kwargs): wrapped_args = tree_map(wrap, args) wrapped_kwargs = tree_map(wrap, kwargs) - return tree_map(unwrap, f(*wrapped_args, **wrapped_kwargs)) + return f(*wrapped_args, **wrapped_kwargs) return wrapper +@contextlib.contextmanager +def _DispatchBelowAutograd(): + # TODO: AutogradOther + old = torch._C._dispatch_tls_is_dispatch_key_excluded("AutogradFunctionality") + torch._C._dispatch_tls_set_dispatch_key_excluded("AutogradFunctionality", True) + try: + yield + finally: + torch._C._dispatch_tls_set_dispatch_key_excluded("AutogradFunctionality", old) + + def _make_prim( *, schema: str, @@ -330,16 +335,33 @@ def _make_prim( meta(*args, **kwargs) return impl_aten(*args, **kwargs) + class BackwardsNotSupported(torch.autograd.Function): + @staticmethod + def forward(ctx, args_spec, *flat_args): + args, kwargs = tree_unflatten(flat_args, args_spec) # type: ignore[arg-type] + with _DispatchBelowAutograd(): + return _prim(*args, **kwargs) + + @staticmethod + def backward(ctx, *args): + raise RuntimeError("backwards not supported on prim") + + def _autograd_impl(*args, **kwargs): + flat_args, args_spec = tree_flatten((args, kwargs)) + return BackwardsNotSupported.apply(args_spec, *flat_args) + name = schema.split("(")[0] prim_impl.impl(name, _prim_impl) + prim_autograd_impl.impl(name, _autograd_impl) prim_meta_impl.impl(name, _wrap_tensor_meta(meta)) - _prim = getattr(torch.ops.prims, name).default + _prim_packet = getattr(torch.ops.prims, name) + _prim = _prim_packet.default - _prim.__doc__ = doc - _prim.meta = meta # type: ignore[attr-defined] - 
_prim.impl_nvfuser = impl_nvfuser # type: ignore[attr-defined] - _prim.return_type = return_type # type: ignore[attr-defined] + for p in (_prim_packet, _prim): + p.__doc__ = doc + p.impl_nvfuser = impl_nvfuser # type: ignore[attr-defined] + p.return_type = return_type # type: ignore[attr-defined] return _prim @@ -355,7 +377,7 @@ def _elementwise_meta( *args, type_promotion: ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND, args_with_fixed_dtypes: Tuple[TensorLikeType, ...] = None, -) -> TensorMeta: +) -> FakeTensor: """ Meta function for elementwise operations that produce outputs in the same dtype as their inputs. @@ -1926,7 +1948,7 @@ device_put = _make_prim( # NOTE: need to model meta scalars # See https://github.com/pytorch/pytorch/issues/78070 -def _item_meta(a: TensorLikeType) -> TensorMeta: +def _item_meta(a: TensorLikeType) -> FakeTensor: number_type = utils.dtype_to_type(a.dtype) return TensorMeta(number_type(-1)) @@ -1948,7 +1970,7 @@ item = _make_prim( # NOTE: need to model meta scalars # See https://github.com/pytorch/pytorch/issues/78070 -def _maximum_value_meta(dtype: torch.dtype) -> TensorMeta: +def _maximum_value_meta(dtype: torch.dtype) -> FakeTensor: number_type = utils.dtype_to_type(dtype) return TensorMeta(number_type(-1)) @@ -1980,7 +2002,7 @@ maximum_value = _make_prim( # NOTE: need to model meta scalars # See https://github.com/pytorch/pytorch/issues/78070 -def _minimum_value_meta(dtype: torch.dtype) -> TensorMeta: +def _minimum_value_meta(dtype: torch.dtype) -> FakeTensor: number_type = utils.dtype_to_type(dtype) return TensorMeta(number_type(-1)) diff --git a/torch/_prims/context.py b/torch/_prims/context.py index f6db24b0770..be47030ab91 100644 --- a/torch/_prims/context.py +++ b/torch/_prims/context.py @@ -1,14 +1,11 @@ -import string from typing import Callable, Sequence, Any, Dict -from itertools import chain import functools import torch -from torch.fx.graph import Graph, Node import torch.overrides -from torch._prims.utils import TensorMeta, 
torch_function_passthrough +from torch._prims.utils import torch_function_passthrough import torch._refs as refs import torch._refs @@ -32,143 +29,6 @@ _torch_to_reference_map = { } -class PrimContext(torch.overrides.TorchFunctionMode): - """ - The prototype prim tracing context. - - Example usage: - - import torch._prims.utils as utils - from torch._prims.context import PrimContext - from torch._prims.executor import execute - from torch.overrides import push_torch_function_mode - - a = torch.randn((2, 2)) - b = torch.randn((2, 2)) - - with push_torch_function_mode(PrimContext): - meta_a = ctx.placeholder(utils.TensorMeta(a)) - meta_b = ctx.placeholder(utils.TensorMeta(b)) - result = torch.add(meta_a, meta_b) - ctx.output(result) - - exc_result = execute(ctx, a, b) - - Currently this only acquires a trace of prims, and - it does not account for control flow. As such, - execute must be called with tensors that have the - same metadata (dtype, device, shape...) as - the tensors used to trace the operations. - - The tracing context's FX graph can be acquired - using its graph attribute. 
- """ - - def __init__(self): - self.graph = Graph() - - # Private attributes for generating names - self._tensor_name_counter = 0 - self._dim_name_counter = 0 - self._shape_name_counter = 0 - self._lowercase = tuple(string.ascii_lowercase) - self._uppercase = tuple(string.ascii_uppercase) - - @staticmethod - def _create_name(idx, chars): - name = "" - while idx >= len(chars): - name = chars[idx % len(chars)] + name - idx = idx - len(chars) - name = chars[idx] + name - - return name - - def _tensor_name(self): - idx = self._tensor_name_counter - self._tensor_name_counter = self._tensor_name_counter + 1 - - return self._create_name(idx, self._lowercase) - - def _add_user(self, tm: TensorMeta, node: Node) -> None: - assert tm.node is not None - tm.node.users[node] = None - - def placeholder(self, a: Any): - name = self._tensor_name() - node = self.graph.placeholder(name) - - if isinstance(a, TensorMeta): - if a.node is not None: - raise ValueError("Attempting to reuse a TensorMeta in a new trace!") - a.tname = name - a.node = node - - return a - - def output(self, tms: Sequence[TensorMeta]): - # TODO: allow other output types - flat_tms, _ = torch.utils._pytree.tree_flatten(tms) - for tm in flat_tms: - assert isinstance(tm, TensorMeta), f"Got non-TensorMeta output!, {type(tm)}" - - node = self.graph.output(tms) - for tm in flat_tms: - self._add_user(tm, node) - - def __torch_function__( - self, - func: Callable, - types: Sequence, - args: Sequence[Any] = (), - kwargs: Dict = None, - ): - """ - Determines which function to call. 
The order of which - function is called is determined by: - - - func's "meta" attribute, if it exists - - if func is a torch operation, its corresponding reference - - func - """ - - if kwargs is None: - kwargs = {} - - if hasattr(func, "meta"): - # TODO: add check that all args/kwargs are 'registered' properly - # to this trace - - output = func.meta(*args, **kwargs) # type: ignore[attr-defined] - - # Updates graph - # TODO: handle outputs with multiple tensors - # TODO: handle non-tensor outputs - assert isinstance(output, TensorMeta) - output_name = self._tensor_name() - node = self.graph.create_node( - "call_function", func, name=output_name, args=args, kwargs=kwargs - ) - output.tname = output_name - output.node = node - - # Marks uses - for x in ( - x for x in chain(args, kwargs.values()) if isinstance(x, TensorMeta) - ): - self._add_user(x, node) - - return output - - # Remaps torch operations to their references - if func in _torch_to_reference_map: - fn = _torch_to_reference_map[func] - with torch.overrides.enable_torch_function_mode(self, replace=self.inner): - return fn(*args, **kwargs) # type: ignore[operator] - - return func(*args, **kwargs) - - @functools.lru_cache(None) def torch_to_refs_map(): """ @@ -183,7 +43,7 @@ def torch_to_refs_map(): ] r = {} for mod_torch, mod_refs in modules: - for s in mod_refs.__all__: + for s in mod_refs.__all__: # type: ignore[attr-defined] r[mod_torch.__dict__.get(s)] = mod_refs.__dict__.get(s) return r diff --git a/torch/_prims/executor.py b/torch/_prims/executor.py index dc878bb904c..dc6bbbefaa8 100644 --- a/torch/_prims/executor.py +++ b/torch/_prims/executor.py @@ -3,15 +3,17 @@ from typing import Callable import torch from torch.fx import GraphModule -from torch._prims.utils import TensorMeta, getnvFuserDtype -from torch._prims.context import PrimContext +from torch.fx.experimental.proxy_tensor import make_fx +from torch._prims.utils import getnvFuserDtype +from torch._prims.context import TorchRefsMode import 
torch.overrides +from torch.utils._pytree import tree_map if torch.cuda.is_available(): from torch._C._nvfuser import Fusion, FusionDefinition # type: ignore[import] -def execute(ctx: PrimContext, *args, executor: str = "aten", **kwargs): +def execute(gm: GraphModule, *args, executor: str = "aten", **kwargs): """ Prototype ATen executor. @@ -19,7 +21,6 @@ def execute(ctx: PrimContext, *args, executor: str = "aten", **kwargs): """ if executor == "aten": - gm = GraphModule({}, ctx.graph) return gm.forward(*args, **kwargs) elif executor == "nvfuser": if not torch.cuda.is_available(): @@ -28,36 +29,32 @@ def execute(ctx: PrimContext, *args, executor: str = "aten", **kwargs): ) # PROTOTYPE nvfuser executor - # Only accepts tensor inputs and single tensor outputs - # Does not handle kwargs - # Does not support reusing the same ctx to execute! - assert len(kwargs) == 0 - # TODO: make this a proper trace -> trace transform that - # doesn't mutate the context - graph_fd = ctx.graph.placeholder("fd") - ctx.graph._root.append(graph_fd) + # Everything in the graph must support nvfuser fusion = Fusion() with FusionDefinition(fusion) as fd: - # Transforms graph to call nvfuser lowerings - nv_args = [fd] - for arg in args: + + class FusionInterpreter(torch.fx.Interpreter): + def call_function(self, target, args, kwargs): + target = target.impl_nvfuser + args = (fd,) + args + return target(*args, **kwargs) + + def to_nv(arg): if isinstance(arg, torch.Tensor): x = fd.define_tensor( arg.size(), arg.stride(), getnvFuserDtype(arg.dtype) ) fd.add_input(x) - nv_args.append(x) + return x else: - nv_args.append(x) + return arg - for x in ctx.graph.nodes: - if x.op == "call_function": - x.target = x.target.impl_nvfuser - x.args = (graph_fd,) + x.args + # Transforms graph to call nvfuser lowerings + nv_args = tree_map(to_nv, args) + nv_kwargs = tree_map(to_nv, kwargs) - gm = GraphModule({}, ctx.graph) - out = gm.forward(*nv_args) + out = FusionInterpreter(gm).run(*nv_args, **nv_kwargs) 
flat_out, unflatten_spec = torch.utils._pytree.tree_flatten(out) for o in flat_out: fd.add_output(o) @@ -102,17 +99,9 @@ def make_traced(fn: Callable): """ def _traced(*args, executor="aten"): - ctx: PrimContext - with torch.overrides.push_torch_function_mode(PrimContext) as ctx: # type: ignore[attr-defined, assignment] - placeholders = [] - for arg in args: - if isinstance(arg, torch.Tensor): - placeholders.append(ctx.placeholder(TensorMeta(arg))) - else: - placeholders.append(ctx.placeholder(arg)) - - result = fn(*placeholders) - ctx.output(result) - return execute(ctx, *args, executor=executor) + # TODO: caching + with TorchRefsMode.push(): + gm = make_fx(fn)(*args) + return execute(gm, *args, executor=executor) return _traced diff --git a/torch/_prims/utils.py b/torch/_prims/utils.py index 2a2d8f1d295..e752b277e5c 100644 --- a/torch/_prims/utils.py +++ b/torch/_prims/utils.py @@ -1,9 +1,10 @@ from __future__ import annotations -from typing import Any, Union, Sequence, Optional, Callable, Dict, Tuple, List +from typing import Any, Union, Sequence, Optional, Tuple, List from enum import Enum from functools import reduce, cmp_to_key import operator +from torch._subclasses.fake_tensor import FakeTensor import torch @@ -55,119 +56,57 @@ torch_function_passthrough = { } -class TensorMeta(torch.Tensor): - """ - Model tensor metadata. 
Not a stock meta tensor because device is modeled - as the original device (not meta device), also we have different behavior - for some high level Python bindings - """ - - # Note: this will be an fx Node if it's ever - # populated, but some Meta-internal jobs don't include fx - node: Optional[Any] - tname: str - - @staticmethod - def __new__( - cls, - tensorlike: Optional[Union[TensorMeta, NumberType, torch.Tensor]] = None, - *, - shape: Optional[ShapeType] = None, - strides: Optional[StrideType] = None, - dtype: Optional[torch.dtype] = None, - device: Optional[Union[torch.device, str]] = None, - ): - - if isinstance(tensorlike, Number): - assert not shape and (shape is None or isinstance(shape, Sequence)) - assert not strides and (strides is None or isinstance(strides, Sequence)) - inferred_shape: Tuple[int, ...] = () - inferred_strides: Tuple[int, ...] = () - inferred_dtype = type_to_dtype(type(tensorlike)) - inferred_device = torch.device("cpu") - # TODO: This looks wrong, a number that is wrapped into a tensor - # needs to behave differently than a scalar tensor for type - # promotion purposes - elif tensorlike is not None: - assert isinstance(tensorlike, (TensorMeta, torch.Tensor)) - inferred_shape = tuple(tensorlike.shape) - inferred_strides = tuple(tensorlike.stride()) - inferred_dtype = tensorlike.dtype - inferred_device = tensorlike.device - else: - # If no tensorlike "example" is given then all metadata - # must be provided explicitly - assert shape is not None - assert strides is not None - assert dtype is not None - assert device is not None - - shape = inferred_shape if shape is None else tuple(shape) - strides = inferred_strides if strides is None else tuple(strides) - dtype = inferred_dtype if dtype is None else dtype - device = inferred_device if device is None else device - - if isinstance(device, str): - device = torch.device(device) - - r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined] - cls, - shape, - strides=strides, - 
storage_offset=0, # TODO: this is inaccurate - dtype=dtype, - device=device, - requires_grad=False, - ) - - r.tname = "" - r.node = None - return r - - @classmethod - def __torch_function__( - cls, - func: Callable, - types: Sequence, - args: Sequence[Any] = (), - kwargs: Optional[Dict] = None, - ): - if kwargs is None: - kwargs = {} - - if func in torch_function_passthrough: - return super().__torch_function__(func, types, args, kwargs) - - if not hasattr(func, "meta"): - raise ValueError(f"Callable {func} has no meta function!") - - return func.meta(*args, **kwargs) # type: ignore[attr-defined] - - @classmethod - def __torch_dispatch__( - cls, - func, - types, - args=(), - kwargs=None, - ): - raise RuntimeError("this should be unreachable") - - # TODO: fx uses dunder repr to print objects in code - def __repr__(self): - return self.tname - # return f"TensorMeta(dtype={self.dtype}, device={self.device}, shape={self.shape}, strides={self.stride()})" - - def __format__(self, format_spec): - return self.tname - - -TensorLikeType = Union[torch.Tensor, TensorMeta] -TensorLike = (torch.Tensor, TensorMeta) +TensorLikeType = torch.Tensor +TensorLike = torch.Tensor TensorSequenceType = Union[List[TensorLikeType], Tuple[TensorLikeType, ...]] TensorOrNumberLikeType = Union[TensorLikeType, NumberType] +def TensorMeta( + tensorlike: Optional[Union[NumberType, torch.Tensor]] = None, + *, + shape: Optional[ShapeType] = None, + strides: Optional[StrideType] = None, + dtype: Optional[torch.dtype] = None, + device: Optional[Union[torch.device, str]] = None, +): + if isinstance(tensorlike, Number): + assert not shape and (shape is None or isinstance(shape, Sequence)) + assert not strides and (strides is None or isinstance(strides, Sequence)) + inferred_shape: Tuple[int, ...] = () + inferred_strides: Tuple[int, ...] 
= () + inferred_dtype = type_to_dtype(type(tensorlike)) + inferred_device = torch.device("cpu") + # TODO: This looks wrong, a number that is wrapped into a tensor + # needs to behave differently than a scalar tensor for type + # promotion purposes + elif tensorlike is not None: + assert isinstance(tensorlike, torch.Tensor) + inferred_shape = tuple(tensorlike.shape) + inferred_strides = tuple(tensorlike.stride()) + inferred_dtype = tensorlike.dtype + inferred_device = tensorlike.device + else: + # If no tensorlike "example" is given then all metadata + # must be provided explicitly + assert shape is not None + assert strides is not None + assert dtype is not None + assert device is not None + + shape = inferred_shape if shape is None else tuple(shape) + strides = inferred_strides if strides is None else tuple(strides) + dtype = inferred_dtype if dtype is None else dtype + device = inferred_device if device is None else device + + if isinstance(device, str): + device = torch.device(device) + + return FakeTensor( + torch.empty_strided(shape, strides, dtype=dtype, device="meta"), device + ) + + # TODO: look at using torch.testing.assert_close instead with an option # to just compare metadata def compare_tensor_meta(a: TensorLikeType, b: TensorLikeType): diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index cfe64efc3a6..2a20e88bb4e 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -36,6 +36,7 @@ import operator import warnings import math from enum import Enum +import collections # Experimental module containing prototype Python references for existing # PyTorch operations. 
@@ -1576,10 +1577,14 @@ def addr( def atleast_1d( - *args: TensorLikeType, + arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType ) -> Union[TensorLikeType, Tuple[TensorLikeType, ...]]: """Reference implementation of :func:`torch.atleast_1d`.""" - args_ = args[0] if len(args) == 1 and not torch.is_tensor(args[0]) else args + if not args and isinstance(arg, collections.abc.Sequence): + args_ = arg + else: + assert not isinstance(arg, collections.abc.Sequence) + args_ = (arg,) + args res = tuple(a if a.ndim >= 1 else unsqueeze(a, 0) for a in args_) return res if len(res) > 1 else res[0] @@ -1595,20 +1600,28 @@ def _unsqueeze_atleast( def atleast_2d( - *args: TensorLikeType, + arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType ) -> Union[TensorLikeType, Tuple[TensorLikeType, ...]]: """Reference implementation of :func:`torch.atleast_2d`.""" - args_ = args[0] if len(args) == 1 and not torch.is_tensor(args[0]) else args + if not args and isinstance(arg, collections.abc.Sequence): + args_ = arg + else: + assert not isinstance(arg, collections.abc.Sequence) + args_ = (arg,) + args unsqueeze_atleast_1d = partial(_unsqueeze_atleast, atleast_1d, 0) res = tuple(a if a.ndim >= 2 else unsqueeze_atleast_1d(a) for a in args_) return res if len(res) > 1 else res[0] def atleast_3d( - *args: TensorLikeType, + arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType ) -> Union[TensorLikeType, Tuple[TensorLikeType, ...]]: """Reference implementation of :func:`torch.atleast_3d`.""" - args_ = args[0] if len(args) == 1 and not torch.is_tensor(args[0]) else args + if not args and isinstance(arg, collections.abc.Sequence): + args_ = arg + else: + assert not isinstance(arg, collections.abc.Sequence) + args_ = (arg,) + args unsqueeze_atleast_2d = partial(_unsqueeze_atleast, atleast_2d, -1) res = tuple(a if a.ndim >= 3 else unsqueeze_atleast_2d(a) for a in args_) return res if len(res) > 1 else res[0] diff --git a/torch/_subclasses/fake_tensor.py 
b/torch/_subclasses/fake_tensor.py index 332f08a12de..88440a68362 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -180,7 +180,6 @@ class FakeTensor(torch.Tensor): # elem does not need to be recorded, because FakeTensor *is a* elem assert elem.device.type == "meta" device = device if isinstance(device, torch.device) else torch.device(device) - assert device.type != "meta" self.fake_device = device @staticmethod diff --git a/torch/_subclasses/meta_utils.py b/torch/_subclasses/meta_utils.py index 6450c613016..0508d8f4774 100644 --- a/torch/_subclasses/meta_utils.py +++ b/torch/_subclasses/meta_utils.py @@ -1,5 +1,4 @@ import torch -from torch._prims.utils import is_complex_dtype, corresponding_real_dtype from torch.utils._mode_utils import no_dispatch def safe_is_leaf(t): @@ -50,8 +49,8 @@ class MetaConverter: base = self.meta_tensor(t._base) def is_c_of_r(complex_dtype, real_dtype): - return is_complex_dtype(complex_dtype) and \ - corresponding_real_dtype(complex_dtype) == real_dtype + return utils.is_complex_dtype(complex_dtype) and \ + utils.corresponding_real_dtype(complex_dtype) == real_dtype if base.dtype == t.dtype: pass @@ -138,3 +137,5 @@ class MetaConverter: else: # non-Tensor types don't count as hit or miss return t + +import torch._prims.utils as utils diff --git a/torch/csrc/utils/python_dispatch.cpp b/torch/csrc/utils/python_dispatch.cpp index eeb5b02622e..90560e21941 100644 --- a/torch/csrc/utils/python_dispatch.cpp +++ b/torch/csrc/utils/python_dispatch.cpp @@ -226,6 +226,13 @@ void initDispatchBindings(PyObject* module) { return states; }); + m.def("_dispatch_tls_set_dispatch_key_excluded", [](const char* dispatch_key, bool desired_state) { + c10::impl::tls_set_dispatch_key_excluded(c10::parseDispatchKey(dispatch_key), desired_state); + }); + m.def("_dispatch_tls_is_dispatch_key_excluded", [](const char* dispatch_key) { + return c10::impl::tls_is_dispatch_key_excluded(c10::parseDispatchKey(dispatch_key)); + }); 
+ // Prints out the name of every operator that has a kernel registered to the Dispatcher // under [dispatch_key]. // If no arguments are specified, it'll print out the name of every operator that the Dispatcher knows of. diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index d934c28f927..e7c1c4c88d6 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -19416,6 +19416,12 @@ def _inherit_constructor_args(name, op, inherited, overrides): kwargs.update(common_kwargs) kwargs.update(overrides) + kwargs['supports_autograd'] = False + kwargs['supports_gradgrad'] = False + kwargs['supports_fwgrad_bwgrad'] = False + kwargs['supports_inplace_autograd'] = False + kwargs['supports_forward_ad'] = False + return kwargs class PythonRefInfo(OpInfo): @@ -19755,10 +19761,6 @@ python_ref_db = [ PythonRefInfo( "_refs.nn.functional.leaky_relu", torch_opinfo_name="nn.functional.leaky_relu", - decorators=( - # Need FakeTensor support for meta coverage - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta',), - ), ), ElementwiseUnaryPythonRefInfo( "_refs.nn.functional.relu", @@ -20078,11 +20080,6 @@ python_ref_db = [ PythonRefInfo( "_refs.stack", torch_opinfo_name="stack", - skips=( - # https://github.com/pytorch/pytorch/issues/77046 - DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), - DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), - ), ), PythonRefInfo( "_refs.squeeze", @@ -20105,10 +20102,6 @@ python_ref_db = [ PythonRefInfo( "_refs.t", torch_opinfo_name="t", - decorators=( - # Need FakeTensor support for meta coverage - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta',), - ), ), PythonRefInfo( "_refs.unsqueeze", @@ -20180,8 +20173,6 @@ python_ref_db = [ "_refs.addr", torch_opinfo_name="addr", decorators=( - # RuntimeError: no _refs 
support for torch.outer - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta',), DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref',), ), ),