[BE] update type annotations for basic utilities in torch/__init__.py (#129001)
Changes:
1. Make some arguments positional-only, as we only support Python 3.8+.
2. Clean up the `torch.typename(obj)` implementation.
3. Update type annotations, especially `is_tensor()` and `is_masked_tensor()`, using `TypeGuard`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129001
Approved by: https://github.com/malfet
This commit is contained in:
parent 1a54bb0f96
commit 93a33bf3ac
22 changed files with 288 additions and 196 deletions
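As a quick orientation before the diff, here is a minimal usage sketch (not part of the commit) of what items 1 and 3 buy in practice: `torch.is_tensor` now returns `TypeGuard["torch.Tensor"]`, so a static checker narrows the argument in the `True` branch, and the new `/` marker makes `obj` positional-only. The helper name `describe` is made up for illustration.

```python
import torch


def describe(x: object) -> str:
    if torch.is_tensor(x):
        # With `def is_tensor(obj: Any, /) -> TypeGuard["torch.Tensor"]`, a type
        # checker narrows `x` to torch.Tensor in this branch, so `.shape`
        # resolves without a cast.
        return f"tensor of shape {tuple(x.shape)}"
    return torch.typename(x)


describe(torch.zeros(2, 3))  # -> "tensor of shape (2, 3)"
describe(3.14)               # -> "float"
# After this change, torch.is_tensor(obj=x) raises TypeError: `obj` is positional-only.
```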
@@ -1188,9 +1188,6 @@ exclude_patterns = [
     'torch/_export/serde/upgrade.py',
     'torch/_export/trace.py',
     'torch/_export/verifier.py',
-    'torch/_higher_order_ops/__init__.py',
-    'torch/_higher_order_ops/out_dtype.py',
-    'torch/_higher_order_ops/wrap.py',
     'torch/_vendor/**',
     'torch/ao/__init__.py',
     'torch/ao/nn/__init__.py',
@@ -1036,7 +1036,7 @@ class FakeTensorConstHandling(TestCase):
 make_propagate_real_tensors_cls(FakeTensorConstHandling)


-def contains_type(type: torch._C.Type, maybe_contained_type: torch._C.Type):
+def contains_type(type: torch.Type, maybe_contained_type: torch.Type):
     return maybe_contained_type.isSubtypeOf(type) or any(
         contains_type(e, maybe_contained_type) for e in type.containedTypes()
     )
@ -1,19 +1,26 @@
|
|||
# Owner(s): ["module: autograd"]
|
||||
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests, IS_JETSON, IS_WINDOWS, IS_MACOS, skipIfTorchDynamo
|
||||
from torch._utils_internal import get_file_path_2
|
||||
|
||||
import pkgutil
|
||||
import torch
|
||||
import importlib
|
||||
from typing import Callable
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import pkgutil
|
||||
import unittest
|
||||
from importlib import import_module
|
||||
from itertools import chain
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
from torch._utils_internal import get_file_path_2
|
||||
from torch.testing._internal.common_utils import (
|
||||
IS_JETSON,
|
||||
IS_MACOS,
|
||||
IS_WINDOWS,
|
||||
run_tests,
|
||||
skipIfTorchDynamo,
|
||||
TestCase,
|
||||
)
|
||||
|
||||
|
||||
def _find_all_importables(pkg):
|
||||
"""Find all importables in the project.
|
||||
|
|
@@ -56,6 +63,19 @@ def _discover_path_importables(pkg_pth, pkg_name):


 class TestPublicBindings(TestCase):
+    def test_no_new_reexport_callables(self):
+        """
+        This test aims to stop the introduction of new re-exported callables into
+        torch whose names do not start with _. Such callables are made available as
+        torch.XXX, which may not be desirable.
+        """
+        reexported_callables = sorted(
+            k
+            for k, v in vars(torch).items()
+            if callable(v) and not v.__module__.startswith('torch')
+        )
+        self.assertTrue(all(k.startswith('_') for k in reexported_callables), reexported_callables)
+
     def test_no_new_bindings(self):
         """
         This test aims to stop the introduction of new JIT bindings into torch._C
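Roughly speaking (an illustrative sketch, not part of the diff), the new test works because a callable re-exported into the `torch` namespace keeps the `__module__` of the place where it was defined, so anything that did not originate under `torch.*` stands out. The standalone snippet below applies the same detection expression, with a `getattr` guard added for callables that lack `__module__`.

```python
import torch

# Names visible as `torch.XXX` whose implementation lives outside `torch.*`.
# On a clean build this should print an empty list.
offenders = sorted(
    name
    for name, value in vars(torch).items()
    if callable(value)
    and getattr(value, "__module__", None)
    and not value.__module__.startswith("torch")
    and not name.startswith("_")
)
print(offenders)
```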
@ -278,7 +298,6 @@ class TestPublicBindings(TestCase):
|
|||
return False
|
||||
return True
|
||||
|
||||
|
||||
@unittest.skipIf(IS_WINDOWS or IS_MACOS, "Inductor/Distributed modules hard fail on windows and macos")
|
||||
@skipIfTorchDynamo("Broken and not relevant for now")
|
||||
def test_modules_can_be_imported(self):
|
||||
|
|
@ -289,7 +308,7 @@ class TestPublicBindings(TestCase):
|
|||
# which calls sys.exit() when we try to import it
|
||||
if "__main__" in modname:
|
||||
continue
|
||||
import_module(modname)
|
||||
importlib.import_module(modname)
|
||||
except Exception as e:
|
||||
# Some current failures are not ImportError
|
||||
failures.append((modname, type(e)))
|
||||
|
|
|
|||
|
|
@ -1086,12 +1086,12 @@ def gen_pyi(
|
|||
"def __init__(self, other: Tensor) -> None: ...",
|
||||
f"def __init__(self, size: _size, *, {DEVICE_PARAM}) -> None: ...",
|
||||
],
|
||||
"as_subclass": ["def as_subclass(self, cls: Type[S]) -> S: ..."],
|
||||
"as_subclass": ["def as_subclass(self, cls: _Type[S]) -> S: ..."],
|
||||
"_make_subclass": [
|
||||
"@staticmethod \ndef _make_subclass({}) -> S: ...".format(
|
||||
", ".join(
|
||||
[
|
||||
"cls: Type[S]",
|
||||
"cls: _Type[S]",
|
||||
"data: Tensor",
|
||||
"require_grad: _bool = False",
|
||||
"dispatch_strides: _bool = False",
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ from typing import (
|
|||
Set,
|
||||
SupportsIndex,
|
||||
Tuple,
|
||||
Type,
|
||||
Type as _Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
overload,
|
||||
|
|
@ -35,12 +35,17 @@ from typing_extensions import ParamSpec, Self
|
|||
import numpy
|
||||
|
||||
import torch
|
||||
from torch import inf, SymInt, Tensor
|
||||
from torch import SymInt, Tensor, inf
|
||||
from torch._prims_common import DeviceLikeType
|
||||
from torch.autograd.graph import Node as _Node
|
||||
from torch.package import PackageExporter
|
||||
from torch.storage import UntypedStorage, TypedStorage
|
||||
from torch.storage import TypedStorage, UntypedStorage
|
||||
from torch.types import (
|
||||
Device,
|
||||
Number,
|
||||
Storage,
|
||||
_bool,
|
||||
_bytes,
|
||||
_complex,
|
||||
_device,
|
||||
_dispatchkey,
|
||||
|
|
@ -50,17 +55,23 @@ from torch.types import (
|
|||
_layout,
|
||||
_qscheme,
|
||||
_size,
|
||||
Device,
|
||||
Number,
|
||||
Storage,
|
||||
_str,
|
||||
)
|
||||
|
||||
from torch._prims_common import DeviceLikeType
|
||||
from torch.utils._python_dispatch import TorchDispatchMode
|
||||
|
||||
# This module is defined in torch/csrc/Module.cpp
|
||||
from . import (
|
||||
_aoti,
|
||||
_cpu,
|
||||
_functorch,
|
||||
_lazy,
|
||||
_lazy_ts_backend,
|
||||
_nn,
|
||||
_onnx,
|
||||
_VariableFunctions,
|
||||
_verbose,
|
||||
)
|
||||
|
||||
from . import _functorch, _lazy, _lazy_ts_backend, _nn, _onnx, _VariableFunctions, _cpu, _aoti, _verbose
|
||||
# This module is defined in torch/csrc/Module.cpp
|
||||
|
||||
K = TypeVar("K")
|
||||
T = TypeVar("T")
|
||||
|
|
@ -1105,7 +1116,7 @@ class Module: ...
|
|||
def _initExtension(shm_manager_path: str) -> None: ... # THPModule_initExtension
|
||||
def _autograd_init() -> _bool: ... # THPAutograd_initExtension
|
||||
def _add_docstr(obj: T, doc_obj: str) -> T: ... # THPModule_addDocStr
|
||||
def _init_names(arg: Sequence[Type]) -> None: ... # THPModule_initNames
|
||||
def _init_names(arg: Sequence[_Type]) -> None: ... # THPModule_initNames
|
||||
def _has_distributed() -> _bool: ... # THPModule_hasDistributed
|
||||
def _set_default_tensor_type(type) -> None: ... # THPModule_setDefaultTensorType
|
||||
def _set_default_dtype(d: _dtype) -> None: ... # THPModule_setDefaultDtype
|
||||
|
|
@ -1235,13 +1246,13 @@ def _log_api_usage_metadata(event: str, metadata_map: Dict[str, str]) -> None: .
|
|||
def _demangle(str) -> str: ... # c10::demangle
|
||||
def _disabled_torch_function_impl(
|
||||
func: Callable,
|
||||
types: Iterable[Type],
|
||||
types: Iterable[_Type],
|
||||
args: Tuple,
|
||||
kwargs: Dict,
|
||||
) -> Any: ... # THPModule_disable_torch_function
|
||||
def _disabled_torch_dispatch_impl(
|
||||
func: Callable,
|
||||
types: Iterable[Type],
|
||||
types: Iterable[_Type],
|
||||
args: Tuple,
|
||||
kwargs: Dict,
|
||||
) -> Any: ... # THPModule_disable_dispatch_function
|
||||
|
|
@ -1455,7 +1466,7 @@ def _get_privateuse1_backend_name() -> str: ...
|
|||
class Generator:
|
||||
device: _device
|
||||
def __init__(self, device: Optional[DeviceLikeType] = None) -> None: ...
|
||||
def __reduce__(self) -> Tuple[Type[Generator], Tuple[_device], Tuple[_int, Optional[_int], Tensor]]: ...
|
||||
def __reduce__(self) -> Tuple[_Type[Generator], Tuple[_device], Tuple[_int, Optional[_int], Tensor]]: ...
|
||||
def __setstate__(self, state: Tuple[_int, Optional[_int], Tensor]) -> None: ...
|
||||
def get_state(self) -> Tensor: ...
|
||||
def set_state(self, _new_state: Tensor) -> Generator: ...
|
||||
|
|
@ -2146,6 +2157,24 @@ class InferredType:
|
|||
|
||||
R = TypeVar("R", bound=JitType)
|
||||
|
||||
class Type(JitType):
|
||||
def str(self) -> _str: ...
|
||||
def containedTypes(self) -> List[JitType]: ...
|
||||
def dim(self) -> Optional[_int]: ...
|
||||
def undefined(self) -> Optional[_bool]: ...
|
||||
def sizes(self) -> Optional[List[_int]]: ...
|
||||
def symbol_sizes(self) -> Optional[List[_int]]: ...
|
||||
def varyingSizes(self) -> Optional[List[Optional[_int]]]: ...
|
||||
def strides(self) -> Optional[List[_int]]: ...
|
||||
def contiguous(self) -> Self: ...
|
||||
def device(self) -> Optional[_device]: ...
|
||||
def __eq__(self, other: object) -> _bool: ...
|
||||
__hash__ = None # type: ignore[assignment]
|
||||
def is_interface_type(self) -> _bool: ...
|
||||
def requires_grad(self) -> _bool: ...
|
||||
@property
|
||||
def annotation_string(self) -> _str: ...
|
||||
|
||||
class AnyType(JitType):
|
||||
@staticmethod
|
||||
def get() -> AnyType: ...
|
||||
|
|
|
|||
|
|
@ -14,7 +14,6 @@ import builtins
|
|||
import ctypes
|
||||
import glob
|
||||
import importlib
|
||||
import importlib.util
|
||||
import inspect
|
||||
import math
|
||||
import os
|
||||
|
|
@ -22,13 +21,24 @@ import platform
|
|||
import sys
|
||||
import textwrap
|
||||
import threading
|
||||
from typing import Any, Callable, Dict, Optional, Set, Tuple, Type, TYPE_CHECKING, Union
|
||||
from typing import (
|
||||
Any as _Any,
|
||||
Callable as _Callable,
|
||||
Dict as _Dict,
|
||||
Optional as _Optional,
|
||||
Set as _Set,
|
||||
Tuple as _Tuple,
|
||||
Type as _Type,
|
||||
TYPE_CHECKING,
|
||||
Union as _Union,
|
||||
)
|
||||
from typing_extensions import TypeGuard as _TypeGuard
|
||||
|
||||
|
||||
# multipy/deploy is setting this import before importing torch, this is the most
|
||||
# reliable way we have to detect if we're running within deploy.
|
||||
# https://github.com/pytorch/multipy/blob/d60f34ad38c371e441fe7ffdb77a3c3dda5a5d19/multipy/runtime/interpreter/interpreter_impl.cpp#L134-L137
|
||||
def _running_with_deploy():
|
||||
def _running_with_deploy() -> builtins.bool:
|
||||
return sys.modules.get("torch._meta_registrations", None) is object
|
||||
|
||||
|
||||
|
|
@ -131,7 +141,7 @@ assert __all__ == sorted(__all__)
|
|||
|
||||
if sys.platform == "win32":
|
||||
|
||||
def _load_dll_libraries():
|
||||
def _load_dll_libraries() -> None:
|
||||
import sysconfig
|
||||
|
||||
from torch.version import cuda as cuda_version
|
||||
|
|
@ -246,7 +256,7 @@ if sys.platform == "win32":
|
|||
del _load_dll_libraries
|
||||
|
||||
|
||||
def _preload_cuda_deps(lib_folder, lib_name):
|
||||
def _preload_cuda_deps(lib_folder: str, lib_name: str) -> None:
|
||||
"""Preloads cuda deps if they could not be found otherwise."""
|
||||
# Should only be called on Linux if default path resolution have failed
|
||||
assert platform.system() == "Linux", "Should only be called on Linux"
|
||||
|
|
@ -284,7 +294,7 @@ def _load_global_deps() -> None:
|
|||
except OSError as err:
|
||||
# Can only happen for wheel with cuda libs as PYPI deps
|
||||
# As PyTorch is not purelib, but nvidia-*-cu12 is
|
||||
cuda_libs: Dict[str, str] = {
|
||||
cuda_libs: _Dict[str, str] = {
|
||||
"cublas": "libcublas.so.*[0-9]",
|
||||
"cudnn": "libcudnn.so.*[0-9]",
|
||||
"cuda_nvrtc": "libnvrtc.so.*[0-9]",
|
||||
|
|
@ -391,14 +401,14 @@ class SymInt:
|
|||
|
||||
def __floordiv__(self, other):
|
||||
if isinstance(other, (builtins.float, SymFloat)):
|
||||
return torch.sym_float(math.floor(sym_float(self) / other))
|
||||
return sym_float(math.floor(sym_float(self) / other))
|
||||
if not isinstance(other, (builtins.int, SymInt)):
|
||||
return NotImplemented
|
||||
return self.__int_floordiv__(other)
|
||||
|
||||
def __rfloordiv__(self, other):
|
||||
if isinstance(other, (builtins.float, SymFloat)):
|
||||
return torch.sym_float(math.floor(other / sym_float(self)))
|
||||
return sym_float(math.floor(other / sym_float(self)))
|
||||
if not isinstance(other, (builtins.int, SymInt)):
|
||||
return NotImplemented
|
||||
return self.__rint_floordiv__(other)
|
||||
|
|
@ -528,12 +538,12 @@ class SymFloat:
|
|||
def __floordiv__(self, other):
|
||||
if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)):
|
||||
return NotImplemented
|
||||
return torch.sym_float(math.floor(self / sym_float(other)))
|
||||
return sym_float(math.floor(self / sym_float(other)))
|
||||
|
||||
def __rfloordiv__(self, other):
|
||||
if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)):
|
||||
return NotImplemented
|
||||
return torch.sym_float(math.floor(sym_float(other) / self))
|
||||
return sym_float(math.floor(sym_float(other) / self))
|
||||
|
||||
def __bool__(self):
|
||||
return self.node.bool_()
|
||||
|
|
@ -858,7 +868,7 @@ if not TYPE_CHECKING:
|
|||
__name, __candidate = "", None
|
||||
for __name in dir(_C):
|
||||
__candidate = getattr(_C, __name)
|
||||
if type(__candidate) is type(_C):
|
||||
if inspect.ismodule(__candidate):
|
||||
# submodule
|
||||
sys.modules.setdefault(f"{__name__}._C.{__name}", __candidate)
|
||||
|
||||
|
|
@@ -870,44 +880,42 @@ if not TYPE_CHECKING:
 ################################################################################


-def typename(o):
+def typename(obj: _Any, /) -> str:
     """
     String representation of the type of an object.

     This function returns a fully qualified string representation of an object's type.
     Args:
-        o (Object): The object whose type to represent
+        obj (object): The object whose type to represent
     Returns:
         str: the type of the object `o`
     Example:
-        >>> x = torch.tensor([1,2,3])
+        >>> x = torch.tensor([1, 2, 3])
         >>> torch.typename(x)
         'torch.LongTensor'
         >>> torch.typename(torch.nn.Parameter)
         'torch.nn.parameter.Parameter'
     """
-    if isinstance(o, torch.Tensor):
-        return o.type()
+    if isinstance(obj, torch.Tensor):
+        return obj.type()

-    module = ""
-    class_name = ""
-    if (
-        hasattr(o, "__module__")
-        and o.__module__ != "builtins"
-        and o.__module__ != "__builtin__"
-        and o.__module__ is not None
-    ):
-        module = o.__module__ + "."
+    module = getattr(obj, "__module__", "") or ""
+    qualname = ""

-    if hasattr(o, "__qualname__"):
-        class_name = o.__qualname__
-    elif hasattr(o, "__name__"):
-        class_name = o.__name__
+    if hasattr(obj, "__qualname__"):
+        qualname = obj.__qualname__
+    elif hasattr(obj, "__name__"):
+        qualname = obj.__name__
     else:
-        class_name = o.__class__.__name__
+        module = obj.__class__.__module__ or ""
+        qualname = obj.__class__.__qualname__

-    return module + class_name
+    if module in {"", "builtins"}:
+        return qualname
+    return f"{module}.{qualname}"


-def is_tensor(obj):
+def is_tensor(obj: _Any, /) -> _TypeGuard["torch.Tensor"]:
     r"""Returns True if `obj` is a PyTorch tensor.

     Note that this function is simply doing ``isinstance(obj, Tensor)``.
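A few concrete outputs of the rewritten `typename` (an illustrative sketch, not part of the diff): tensors still report their legacy tensor-type string, objects from `builtins` now collapse to the bare qualified name, and everything else is rendered as `module.qualname`.

```python
import torch

print(torch.typename(torch.tensor([1, 2, 3])))  # torch.LongTensor (tensors defer to .type())
print(torch.typename(torch.nn.Parameter))       # torch.nn.parameter.Parameter
print(torch.typename(3))                        # int  (the "builtins" module prefix is dropped)
print(torch.typename(len))                      # len  (builtin callable, likewise no prefix)
```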
@ -916,7 +924,7 @@ def is_tensor(obj):
|
|||
``is_tensor``.
|
||||
|
||||
Args:
|
||||
obj (Object): Object to test
|
||||
obj (object): Object to test
|
||||
Example::
|
||||
|
||||
>>> x = torch.tensor([1, 2, 3])
|
||||
|
|
@ -927,7 +935,7 @@ def is_tensor(obj):
|
|||
return isinstance(obj, torch.Tensor)
|
||||
|
||||
|
||||
def is_storage(obj):
|
||||
def is_storage(obj: _Any, /) -> _TypeGuard[_Union["TypedStorage", "UntypedStorage"]]:
|
||||
r"""Returns True if `obj` is a PyTorch storage object.
|
||||
|
||||
Args:
|
||||
|
|
@ -942,6 +950,7 @@ _GLOBAL_DEVICE_CONTEXT = threading.local()
|
|||
def get_default_device() -> "torch.device":
|
||||
r"""Gets the default ``torch.Tensor`` to be allocated on ``device``"""
|
||||
global _GLOBAL_DEVICE_CONTEXT
|
||||
|
||||
if hasattr(_GLOBAL_DEVICE_CONTEXT, "device_context"):
|
||||
device = _GLOBAL_DEVICE_CONTEXT.device_context.device
|
||||
if device.index is not None:
|
||||
|
|
@ -954,7 +963,9 @@ def get_default_device() -> "torch.device":
|
|||
return torch.device("cpu")
|
||||
|
||||
|
||||
def set_default_device(device):
|
||||
def set_default_device(
|
||||
device: _Optional[_Union["torch.device", str, builtins.int]],
|
||||
) -> None:
|
||||
"""Sets the default ``torch.Tensor`` to be allocated on ``device``. This
|
||||
does not affect factory function calls which are called with an explicit
|
||||
``device`` argument. Factory calls will be performed as if they
|
||||
|
|
@ -1016,7 +1027,7 @@ def set_default_device(device):
|
|||
_GLOBAL_DEVICE_CONTEXT.device_context = device_context
|
||||
|
||||
|
||||
def set_default_tensor_type(t):
|
||||
def set_default_tensor_type(t: _Union[_Type["torch.Tensor"], str], /) -> None:
|
||||
r"""
|
||||
.. warning::
|
||||
|
||||
|
|
@ -1047,7 +1058,7 @@ def set_default_tensor_type(t):
|
|||
_C._set_default_tensor_type(t)
|
||||
|
||||
|
||||
def set_default_dtype(d):
|
||||
def set_default_dtype(d: "torch.dtype", /) -> None:
|
||||
r"""
|
||||
|
||||
Sets the default floating point dtype to :attr:`d`. Supports floating point dtype
|
||||
|
|
@ -1257,7 +1268,7 @@ def is_deterministic_algorithms_warn_only_enabled() -> builtins.bool:
|
|||
return _C._get_deterministic_algorithms_warn_only()
|
||||
|
||||
|
||||
def set_deterministic_debug_mode(debug_mode: Union[builtins.int, str]) -> None:
|
||||
def set_deterministic_debug_mode(debug_mode: _Union[builtins.int, str]) -> None:
|
||||
r"""Sets the debug mode for deterministic operations.
|
||||
|
||||
.. note:: This is an alternative interface for
|
||||
|
|
@ -1316,7 +1327,7 @@ def get_deterministic_debug_mode() -> builtins.int:
|
|||
return 0
|
||||
|
||||
|
||||
def get_float32_matmul_precision() -> builtins.str:
|
||||
def get_float32_matmul_precision() -> str:
|
||||
r"""Returns the current value of float32 matrix multiplication precision. Refer to
|
||||
:func:`torch.set_float32_matmul_precision` documentation for more details.
|
||||
"""
|
||||
|
|
@ -1389,7 +1400,7 @@ def set_float32_matmul_precision(precision: str) -> None:
|
|||
_C._set_float32_matmul_precision(precision)
|
||||
|
||||
|
||||
def set_warn_always(b: builtins.bool) -> None:
|
||||
def set_warn_always(b: builtins.bool, /) -> None:
|
||||
r"""When this flag is False (default) then some PyTorch warnings may only
|
||||
appear once per process. This helps avoid excessive warning information.
|
||||
Setting it to True causes these warnings to always appear, which may be
|
||||
|
|
@ -1419,10 +1430,10 @@ def is_warn_always_enabled() -> builtins.bool:
|
|||
|
||||
def _check_with(
|
||||
error_type,
|
||||
cond: Union[builtins.bool, SymBool],
|
||||
message: Callable[[], str],
|
||||
cond: _Union[builtins.bool, SymBool],
|
||||
message: _Callable[[], str],
|
||||
): # noqa: F811
|
||||
if not isinstance(cond, (builtins.bool, torch.SymBool)):
|
||||
if not isinstance(cond, (builtins.bool, SymBool)):
|
||||
raise TypeError(f"cond must be a bool, but got {type(cond)}")
|
||||
|
||||
from torch.fx.experimental.symbolic_shapes import expect_true
|
||||
|
|
@ -1557,13 +1568,13 @@ def _check_not_implemented(cond, message=None): # noqa: F811
|
|||
|
||||
|
||||
def _check_tensor_all_with(error_type, cond, message=None): # noqa: F811
|
||||
if not torch.is_tensor(cond):
|
||||
if not is_tensor(cond):
|
||||
raise TypeError(f"cond must be a tensor, but got {type(cond)}")
|
||||
|
||||
if not cond.dtype == torch.bool:
|
||||
raise TypeError(f"cond tensor must have dtype torch.bool, but got {cond.dtype}")
|
||||
|
||||
_check_with(error_type, cond._is_all_true().item(), message)
|
||||
_check_with(error_type, cond._is_all_true().item(), message) # type: ignore[arg-type]
|
||||
|
||||
|
||||
# C++ equivalent: `TORCH_CHECK_TENSOR_ALL`
|
||||
|
|
@ -1614,10 +1625,9 @@ from torch.storage import (
|
|||
UntypedStorage,
|
||||
)
|
||||
|
||||
|
||||
# NOTE: New <type>Storage classes should never be added. When adding a new
|
||||
# dtype, use torch.storage.TypedStorage directly.
|
||||
|
||||
|
||||
class ByteStorage(_LegacyStorage):
|
||||
@classproperty
|
||||
def dtype(self):
|
||||
|
|
@ -1805,7 +1815,7 @@ class QUInt2x4Storage(_LegacyStorage):
|
|||
return torch.quint2x4
|
||||
|
||||
|
||||
_storage_classes = {
|
||||
_storage_classes: _Set[_Type[_Union[TypedStorage, UntypedStorage]]] = {
|
||||
UntypedStorage,
|
||||
DoubleStorage,
|
||||
FloatStorage,
|
||||
|
|
@ -1828,7 +1838,7 @@ _storage_classes = {
|
|||
}
|
||||
|
||||
# The _tensor_classes set is initialized by the call to initialize_python_bindings.
|
||||
_tensor_classes: Set[Type] = set()
|
||||
_tensor_classes: _Set[_Type["torch.Tensor"]] = set()
|
||||
|
||||
# If you edit these imports, please update torch/__init__.py.in as well
|
||||
from torch import amp as amp, random as random, serialization as serialization
|
||||
|
|
@ -2067,7 +2077,7 @@ class _TorchCompileInductorWrapper:
|
|||
compiler_name = "inductor"
|
||||
|
||||
def __init__(self, mode, options, dynamic):
|
||||
self.config: Dict[str, Any] = dict()
|
||||
self.config: _Dict[str, _Any] = dict()
|
||||
self.dynamic = dynamic
|
||||
self.apply_mode(mode)
|
||||
self.apply_options(options)
|
||||
|
|
@ -2091,7 +2101,7 @@ class _TorchCompileInductorWrapper:
|
|||
and self.dynamic == other.dynamic
|
||||
)
|
||||
|
||||
def apply_mode(self, mode: Optional[str]):
|
||||
def apply_mode(self, mode: _Optional[str]):
|
||||
if mode is None or mode == "default":
|
||||
pass
|
||||
elif mode in {"reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"}:
|
||||
|
|
@ -2103,13 +2113,13 @@ class _TorchCompileInductorWrapper:
|
|||
f"Unrecognized mode={mode}, should be one of: default, reduce-overhead, max-autotune, max-autotune-no-cudagraphs"
|
||||
)
|
||||
|
||||
def apply_options(self, options: Optional[Dict[str, Any]]):
|
||||
def apply_options(self, options: _Optional[_Dict[str, _Any]]):
|
||||
if not options:
|
||||
return
|
||||
|
||||
from torch._inductor import config
|
||||
|
||||
current_config: Dict[str, Any] = config.shallow_copy_dict()
|
||||
current_config: _Dict[str, _Any] = config.shallow_copy_dict()
|
||||
|
||||
for key, val in options.items():
|
||||
attr_name = key.replace("-", "_")
|
||||
|
|
@ -2181,15 +2191,15 @@ class _TorchCompileWrapper:
|
|||
|
||||
|
||||
def compile(
|
||||
model: Optional[Callable] = None,
|
||||
model: _Optional[_Callable] = None,
|
||||
*,
|
||||
fullgraph: builtins.bool = False,
|
||||
dynamic: Optional[builtins.bool] = None,
|
||||
backend: Union[str, Callable] = "inductor",
|
||||
mode: Union[str, None] = None,
|
||||
options: Optional[Dict[str, Union[str, builtins.int, builtins.bool]]] = None,
|
||||
dynamic: _Optional[builtins.bool] = None,
|
||||
backend: _Union[str, _Callable] = "inductor",
|
||||
mode: _Union[str, None] = None,
|
||||
options: _Optional[_Dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
|
||||
disable: builtins.bool = False,
|
||||
) -> Callable:
|
||||
) -> _Callable:
|
||||
"""
|
||||
Optimizes given model/function using TorchDynamo and specified backend.
|
||||
If you are compiling an :class:`torch.nn.Module`, you can also use :meth:`torch.nn.Module.compile`
|
||||
|
|
@ -2281,7 +2291,7 @@ def compile(
|
|||
# Decorator mode
|
||||
if model is None:
|
||||
|
||||
def fn(model: Callable):
|
||||
def fn(model: _Callable):
|
||||
if model is None:
|
||||
raise RuntimeError("Model can't be None")
|
||||
return compile(
|
||||
|
|
@ -2315,11 +2325,6 @@ def compile(
|
|||
)(model)
|
||||
|
||||
|
||||
from torch import export as export
|
||||
|
||||
from torch._higher_order_ops import cond, while_loop
|
||||
|
||||
|
||||
def _register_device_module(device_type, module):
|
||||
r"""Register an external runtime module of the specific :attr:`device_type`
|
||||
supported by torch.
|
||||
|
|
@ -2340,8 +2345,14 @@ def _register_device_module(device_type, module):
|
|||
sys.modules[torch_module_name] = module
|
||||
|
||||
|
||||
# expose return_types
|
||||
from torch import library as library, return_types as return_types
|
||||
from torch import (
|
||||
export as export,
|
||||
func as func,
|
||||
library as library,
|
||||
return_types as return_types,
|
||||
)
|
||||
from torch._higher_order_ops import cond as cond, while_loop as while_loop
|
||||
from torch.func import vmap as vmap
|
||||
|
||||
if not TYPE_CHECKING:
|
||||
from torch import _meta_registrations
|
||||
|
|
@ -2355,10 +2366,6 @@ if "TORCH_CUDA_SANITIZER" in os.environ:
|
|||
# Populate magic methods on SymInt and SymFloat
|
||||
import torch.fx.experimental.sym_node
|
||||
|
||||
from torch import func as func
|
||||
from torch.func import vmap as vmap
|
||||
|
||||
|
||||
# Register MPS specific decomps
|
||||
torch.backends.mps._init()
|
||||
|
||||
|
|
@ -2367,7 +2374,7 @@ if not _running_with_deploy():
|
|||
|
||||
class _TritonLibrary:
|
||||
lib = torch.library.Library("triton", "DEF")
|
||||
ops_table: Dict[Tuple[str, str], Callable] = {}
|
||||
ops_table: _Dict[_Tuple[str, str], _Callable] = {}
|
||||
|
||||
@classmethod
|
||||
def registerOp(cls, op_key, full_schema, op_impl, dispatch_key):
|
||||
|
|
@ -2421,7 +2428,7 @@ else:
|
|||
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
||||
|
||||
|
||||
def get_device_module(device: Optional[Union[torch.device, str]] = None):
|
||||
def get_device_module(device: _Optional[_Union[torch.device, str]] = None):
|
||||
"""
|
||||
Returns the module associated with a given device(e.g., torch.device('cuda'), "mtia:0", "xpu", ...).
|
||||
If no device is given, return the module for the current accelerator or CPU if none is present.
|
||||
|
|
@ -2447,8 +2454,8 @@ def get_device_module(device: Optional[Union[torch.device, str]] = None):
|
|||
|
||||
def _constrain_as_size(
|
||||
symbol,
|
||||
min: Optional[builtins.int] = None,
|
||||
max: Optional[builtins.int] = None,
|
||||
min: _Optional[builtins.int] = None,
|
||||
max: _Optional[builtins.int] = None,
|
||||
):
|
||||
"""
|
||||
This indicates that a given int is size-like, and can be used in any context where a size is expected.
|
||||
|
|
|
|||
|
|
@@ -1,3 +1,14 @@
-from .cond import cond
-from .while_loop import while_loop
-from .flex_attention import flex_attention, flex_attention_backward
+from torch._higher_order_ops.cond import cond
+from torch._higher_order_ops.flex_attention import (
+    flex_attention,
+    flex_attention_backward,
+)
+from torch._higher_order_ops.while_loop import while_loop
+
+
+__all__ = [
+    "cond",
+    "while_loop",
+    "flex_attention",
+    "flex_attention_backward",
+]
|
|
|||
|
|
@ -4,12 +4,9 @@ import itertools
|
|||
from typing import Callable, List
|
||||
|
||||
import torch
|
||||
|
||||
import torch._prims_common as utils
|
||||
import torch._subclasses.functional_tensor
|
||||
|
||||
import torch.utils._pytree as pytree
|
||||
|
||||
from torch._C import DispatchKey
|
||||
from torch._C._functorch import _add_batch_dim, get_unwrapped, maybe_get_bdim
|
||||
from torch._higher_order_ops.utils import (
|
||||
|
|
@ -18,7 +15,6 @@ from torch._higher_order_ops.utils import (
|
|||
reenter_make_fx,
|
||||
unique_graph_id,
|
||||
)
|
||||
|
||||
from torch._ops import HigherOrderOperator
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||
from torch.fx.experimental.proxy_tensor import (
|
||||
|
|
@ -27,6 +23,7 @@ from torch.fx.experimental.proxy_tensor import (
|
|||
track_tensor_tree,
|
||||
)
|
||||
|
||||
|
||||
aten = torch._ops.ops.aten
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -3,9 +3,7 @@ import contextlib
|
|||
|
||||
import torch
|
||||
import torch._subclasses.functional_tensor
|
||||
|
||||
import torch.utils._pytree as pytree
|
||||
|
||||
from torch._C import DispatchKey
|
||||
from torch._C._functorch import (
|
||||
_add_batch_dim,
|
||||
|
|
@ -15,7 +13,6 @@ from torch._C._functorch import (
|
|||
)
|
||||
from torch._functorch.utils import exposed_in
|
||||
from torch._guards import detect_fake_mode
|
||||
|
||||
from torch._higher_order_ops.utils import (
|
||||
_has_potential_branch_input_alias,
|
||||
_has_potential_branch_input_mutation,
|
||||
|
|
@ -25,7 +22,6 @@ from torch._higher_order_ops.utils import (
|
|||
unique_graph_id,
|
||||
UnsupportedAliasMutationException,
|
||||
)
|
||||
|
||||
from torch._ops import HigherOrderOperator
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||
from torch.fx.experimental.proxy_tensor import (
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from typing import Any, Dict, Optional, Tuple, Union
|
|||
import torch
|
||||
import torch.utils._pytree as pytree
|
||||
from torch._C import DispatchKey
|
||||
from torch._higher_order_ops.torchbind import call_torchbind
|
||||
from torch._ops import HigherOrderOperator
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||
from torch.fx.experimental.proxy_tensor import (
|
||||
|
|
@ -12,7 +13,6 @@ from torch.fx.experimental.proxy_tensor import (
|
|||
ProxyTorchDispatchMode,
|
||||
track_tensor_tree,
|
||||
)
|
||||
from .torchbind import call_torchbind
|
||||
|
||||
|
||||
class _EffectType(Enum):
|
||||
|
|
|
|||
|
|
@ -18,7 +18,6 @@ from torch.fx.experimental.proxy_tensor import (
|
|||
track_tensor_tree,
|
||||
)
|
||||
from torch.fx.graph_module import GraphModule
|
||||
|
||||
from torch.overrides import TorchFunctionMode
|
||||
|
||||
|
||||
|
|
@ -288,7 +287,6 @@ def create_fw_bw_graph(score_mod, index_values, other_buffers):
|
|||
from torch._dispatch.python import suspend_functionalization
|
||||
from torch._functorch.aot_autograd import AOTConfig, create_joint
|
||||
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
|
||||
|
||||
from torch._subclasses.functional_tensor import disable_functional_mode
|
||||
from torch.fx.experimental.proxy_tensor import disable_proxy_modes_tracing
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ import torch.utils._pytree as pytree
|
|||
from torch._C import DispatchKey
|
||||
from torch._dispatch.python import suspend_functionalization
|
||||
from torch._functorch.aot_autograd import AOTConfig, create_joint, from_fun
|
||||
|
||||
from torch._higher_order_ops.utils import (
|
||||
_has_potential_branch_input_alias,
|
||||
_has_potential_branch_input_mutation,
|
||||
|
|
|
|||
|
|
@ -2,17 +2,18 @@
|
|||
|
||||
import torch
|
||||
import torch.utils._pytree as pytree
|
||||
from torch._C import DispatchKey
|
||||
from torch._higher_order_ops.utils import autograd_not_implemented
|
||||
from torch._ops import HigherOrderOperator
|
||||
from torch._prims_common import elementwise_dtypes, ELEMENTWISE_TYPE_PROMOTION_KIND
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||
from torch.fx.experimental.proxy_tensor import (
|
||||
disable_proxy_modes_tracing,
|
||||
maybe_handle_decomp,
|
||||
ProxyTorchDispatchMode,
|
||||
track_tensor_tree,
|
||||
maybe_handle_decomp,
|
||||
)
|
||||
from torch._C import DispatchKey
|
||||
from torch._ops import HigherOrderOperator
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||
from torch._prims_common import elementwise_dtypes, ELEMENTWISE_TYPE_PROMOTION_KIND
|
||||
from torch._higher_order_ops.utils import autograd_not_implemented
|
||||
|
||||
|
||||
# TODO to figure out a more generic approach
|
||||
ALLOWABLE_OPS = [
|
||||
|
|
@ -43,7 +44,6 @@ class OutDtypeOperator(HigherOrderOperator):
|
|||
3. Cast the output to `out_dtype`
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("out_dtype")
|
||||
# TODO(ydwu4): Subclassing HigherOrderOperator causes __module__ to
|
||||
|
|
@ -55,10 +55,12 @@ class OutDtypeOperator(HigherOrderOperator):
|
|||
if not isinstance(op, torch._ops.OpOverload):
|
||||
raise ValueError("out_dtype's first argument must be an OpOverload")
|
||||
if op._schema.is_mutable:
|
||||
raise ValueError("out_dtype's first argument needs to be a functional operator")
|
||||
raise ValueError(
|
||||
"out_dtype's first argument needs to be a functional operator"
|
||||
)
|
||||
if not (
|
||||
len(op._schema.returns) == 1 and
|
||||
isinstance(op._schema.returns[0].type, torch.TensorType)
|
||||
len(op._schema.returns) == 1
|
||||
and isinstance(op._schema.returns[0].type, torch.TensorType)
|
||||
):
|
||||
raise ValueError(
|
||||
"out_dtype's can only apply to ops that return a single tensor"
|
||||
|
|
@ -77,6 +79,7 @@ class OutDtypeOperator(HigherOrderOperator):
|
|||
|
||||
out_dtype = OutDtypeOperator()
|
||||
|
||||
|
||||
def trace_out_dtype(proxy_mode, func_overload, op, output_dtype, *args):
|
||||
# NB: Long-term we should put the decomposition logic into
|
||||
# ProxyTorchDispatchMode so that people do not need to call maybe_handle_decomp
|
||||
|
|
@ -99,11 +102,7 @@ def trace_out_dtype(proxy_mode, func_overload, op, output_dtype, *args):
|
|||
|
||||
|
||||
@out_dtype.py_impl(DispatchKey.CompositeExplicitAutograd)
|
||||
def out_dtype_dense(
|
||||
op: torch._ops.OpOverload,
|
||||
output_dtype: torch.dtype,
|
||||
*args
|
||||
):
|
||||
def out_dtype_dense(op: torch._ops.OpOverload, output_dtype: torch.dtype, *args):
|
||||
if is_int_mm(op, output_dtype, args):
|
||||
return torch._int_mm(*args)
|
||||
return out_dtype_fallback(op, output_dtype, *args)
|
||||
|
|
@ -111,13 +110,13 @@ def out_dtype_dense(
|
|||
|
||||
def is_int_mm(op, output_dtype, args):
|
||||
return (
|
||||
op == torch.ops.aten.mm.default and
|
||||
output_dtype == torch.int32 and
|
||||
len(args) == 2 and
|
||||
args[0].dtype == torch.int8 and
|
||||
args[1].dtype == torch.int8 and
|
||||
args[0].is_cuda and
|
||||
args[1].is_cuda
|
||||
op == torch.ops.aten.mm.default
|
||||
and output_dtype == torch.int32
|
||||
and len(args) == 2
|
||||
and args[0].dtype == torch.int8
|
||||
and args[1].dtype == torch.int8
|
||||
and args[0].is_cuda
|
||||
and args[1].is_cuda
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -135,7 +134,9 @@ def out_dtype_fallback(op, output_dtype, *args):
|
|||
return res
|
||||
|
||||
|
||||
out_dtype.py_impl(DispatchKey.Autograd)(autograd_not_implemented(out_dtype, deferred_error=True))
|
||||
out_dtype.py_impl(DispatchKey.Autograd)(
|
||||
autograd_not_implemented(out_dtype, deferred_error=True)
|
||||
)
|
||||
|
||||
|
||||
@out_dtype.py_impl(ProxyTorchDispatchMode)
|
||||
|
|
@ -143,7 +144,7 @@ def out_dtype_proxy(
|
|||
mode: ProxyTorchDispatchMode,
|
||||
op: torch._ops.OpOverload,
|
||||
output_dtype: torch.dtype,
|
||||
*args
|
||||
*args,
|
||||
):
|
||||
if mode.enable_tracing:
|
||||
return trace_out_dtype(mode, out_dtype, op, output_dtype, *args)
|
||||
|
|
@ -156,7 +157,7 @@ def out_dtype_fake_tensor_mode(
|
|||
mode: FakeTensorMode,
|
||||
op: torch._ops.OpOverload,
|
||||
output_dtype: torch.dtype,
|
||||
*args
|
||||
*args,
|
||||
):
|
||||
with mode:
|
||||
return out_dtype_dense(op, output_dtype, *args)
|
||||
|
|
|
|||
|
|
@ -1,12 +1,9 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import torch
|
||||
import torch._subclasses.functional_tensor
|
||||
|
||||
import torch.utils._pytree as pytree
|
||||
|
||||
from torch._C import DispatchKey
|
||||
from torch._functorch.utils import exposed_in
|
||||
|
||||
from torch._higher_order_ops.utils import _set_compilation_env, autograd_not_implemented
|
||||
from torch._ops import HigherOrderOperator
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_ten
|
|||
from torch.fx.node import has_side_effect
|
||||
from torch.utils import _pytree as pytree
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# The call_torchbind operator represents a method invocation on a torchbind
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from torch.fx.experimental.proxy_tensor import (
|
|||
track_tensor_tree,
|
||||
)
|
||||
|
||||
|
||||
log = logging.getLogger("torch._dynamo")
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -3,9 +3,7 @@ from typing import Callable, Tuple, Union
|
|||
|
||||
import torch
|
||||
import torch.utils._pytree as pytree
|
||||
|
||||
from torch._C import DispatchKey
|
||||
|
||||
from torch._higher_order_ops.utils import (
|
||||
_has_potential_branch_input_alias,
|
||||
_has_potential_branch_input_mutation,
|
||||
|
|
|
|||
|
|
@ -4,15 +4,16 @@ import itertools
|
|||
import logging
|
||||
|
||||
import torch
|
||||
import torch._dynamo.config
|
||||
from torch._ops import HigherOrderOperator
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
import torch._dynamo.config
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
uid = itertools.count(1)
|
||||
|
||||
|
||||
# Used for testing the HigherOrderOperator mechanism
|
||||
class Wrap(HigherOrderOperator):
|
||||
def __init__(self):
|
||||
|
|
@ -31,8 +32,10 @@ class Wrap(HigherOrderOperator):
|
|||
|
||||
return wrapper()
|
||||
|
||||
|
||||
wrap = Wrap()
|
||||
|
||||
|
||||
class WrapWithSetGradEnabled(HigherOrderOperator):
|
||||
def __init__(self):
|
||||
super().__init__("wrap_with_set_grad_enabled")
|
||||
|
|
@ -47,10 +50,13 @@ class WrapWithSetGradEnabled(HigherOrderOperator):
|
|||
def wrapper():
|
||||
with torch.set_grad_enabled(enable_grad):
|
||||
return wrapped_func(*args, **kwargs)
|
||||
|
||||
return wrapper()
|
||||
|
||||
|
||||
wrap_with_set_grad_enabled = WrapWithSetGradEnabled()
|
||||
|
||||
|
||||
class WrapActivationCheckpoint(HigherOrderOperator):
|
||||
"""
|
||||
This operator is used to wrap torch.utils.checkpoint. This avoids
|
||||
|
|
@ -68,6 +74,7 @@ class WrapActivationCheckpoint(HigherOrderOperator):
|
|||
that duplication/recomputation is done as a compiler pass in the
|
||||
partitioners. See TagActivationCheckpoint for more information.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("wrap_activation_checkpoint")
|
||||
|
||||
|
|
@ -77,14 +84,17 @@ class WrapActivationCheckpoint(HigherOrderOperator):
|
|||
# version of checkpointing.
|
||||
import torch.fx.traceback as fx_traceback
|
||||
from torch.fx import Interpreter
|
||||
|
||||
kwargs["use_reentrant"] = False
|
||||
kwargs["preserve_rng_state"] = False
|
||||
# Using interpreter allows preservation of metadata through torch.compile stack.
|
||||
with fx_traceback.preserve_node_meta():
|
||||
return checkpoint(Interpreter(function).run, *args, **kwargs)
|
||||
|
||||
|
||||
wrap_activation_checkpoint = WrapActivationCheckpoint()
|
||||
|
||||
|
||||
class TagActivationCheckpoint(HigherOrderOperator):
|
||||
"""
|
||||
This operator is supposed to be used only with torch.compile stack. This
|
||||
|
|
@ -136,8 +146,12 @@ class TagActivationCheckpoint(HigherOrderOperator):
|
|||
# `preserve_rng_state` is not a regular kwarg
|
||||
checkpoint_keys.add("preserve_rng_state")
|
||||
|
||||
checkpoint_kwargs = {name: kwargs[name] for name in kwargs.keys() if name in checkpoint_keys}
|
||||
gmod_kwargs = {name: kwargs[name] for name in kwargs.keys() if name not in checkpoint_keys}
|
||||
checkpoint_kwargs = {
|
||||
name: kwargs[name] for name in kwargs.keys() if name in checkpoint_keys
|
||||
}
|
||||
gmod_kwargs = {
|
||||
name: kwargs[name] for name in kwargs.keys() if name not in checkpoint_keys
|
||||
}
|
||||
return checkpoint_kwargs, gmod_kwargs
|
||||
|
||||
def tag_nodes(self, gmod):
|
||||
|
|
@ -150,13 +164,17 @@ class TagActivationCheckpoint(HigherOrderOperator):
|
|||
def __call__(self, gmod, *args, **kwargs):
|
||||
import torch.fx.traceback as fx_traceback
|
||||
from torch.fx import Interpreter
|
||||
|
||||
if "_checkpoint_context_fn" in gmod.meta:
|
||||
assert torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint, \
|
||||
"Passing context_fn to torch.utils.checkpoint is currently not supported under torch.compile"
|
||||
log.warning("""
|
||||
assert (
|
||||
torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint
|
||||
), "Passing context_fn to torch.utils.checkpoint is currently not supported under torch.compile"
|
||||
log.warning(
|
||||
"""
|
||||
Detected that context_fn is passed to torch.utils.checkpoint under torch.compile.
|
||||
Please make sure the checkpointed region does not contain in-place ops (e.g. torch.relu_).
|
||||
""")
|
||||
"""
|
||||
)
|
||||
# use_reentrant is set to False because this op is going to be traced.
|
||||
# And we ensure that AOT Autograd traces through the non reentrant
|
||||
# version of checkpointing.
|
||||
|
|
@ -183,4 +201,5 @@ Please make sure the checkpointed region does not contain in-place ops (e.g. tor
|
|||
with fx_traceback.preserve_node_meta():
|
||||
return Interpreter(gmod).run(*args)
|
||||
|
||||
|
||||
tag_activation_checkpoint = TagActivationCheckpoint()
|
||||
|
|
|
|||
|
|
@@ -2,6 +2,8 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates

 import warnings
+from typing import Any
+from typing_extensions import TypeGuard

 import torch
 from torch.overrides import get_default_nowrap_functions


@@ -13,7 +15,7 @@ __all__ = [
 ]


-def is_masked_tensor(a):
+def is_masked_tensor(obj: Any, /) -> TypeGuard["MaskedTensor"]:
     r"""Returns True if the input is a MaskedTensor, else False

     Args:


@@ -29,7 +31,7 @@ def is_masked_tensor(a):
         >>> is_masked_tensor(mt)
         True
     """
-    return isinstance(a, MaskedTensor)
+    return isinstance(obj, MaskedTensor)


 def _tensors_match(a, b, exact=True, rtol=1e-05, atol=1e-08):
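For symmetry, a small sketch (not from the diff) of how the `TypeGuard`-annotated guard reads at a call site. It assumes the prototype `torch.masked` API (`masked_tensor`, `MaskedTensor.get_mask`) that this commit only re-annotates, and `mask_or_none` is a made-up helper name.

```python
from typing import Optional

import torch
from torch.masked import masked_tensor
from torch.masked.maskedtensor.core import is_masked_tensor


def mask_or_none(x: object) -> Optional[torch.Tensor]:
    if is_masked_tensor(x):
        # TypeGuard["MaskedTensor"] narrows `x` here, so `.get_mask()` type-checks.
        return x.get_mask()
    return None


mt = masked_tensor(torch.arange(4.0), torch.tensor([True, False, True, False]))
print(mask_or_none(mt))
```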
@ -147,13 +149,14 @@ class MaskedTensor(torch.Tensor):
|
|||
if is_masked_tensor(mask) or not torch.is_tensor(mask):
|
||||
raise TypeError("mask must be a Tensor")
|
||||
# Use a Tensor that of the give size for the wrapper.
|
||||
kwargs = {}
|
||||
kwargs["device"] = data.device
|
||||
kwargs["dtype"] = data.dtype
|
||||
kwargs["layout"] = data.layout
|
||||
kwargs["requires_grad"] = requires_grad
|
||||
kwargs["dispatch_sizes_strides_policy"] = "strides"
|
||||
kwargs["dispatch_layout"] = True
|
||||
kwargs = {
|
||||
"device": data.device,
|
||||
"dtype": data.dtype,
|
||||
"layout": data.layout,
|
||||
"requires_grad": requires_grad,
|
||||
"dispatch_sizes_strides_policy": "strides",
|
||||
"dispatch_layout": True,
|
||||
}
|
||||
warnings.warn(
|
||||
(
|
||||
"The PyTorch API of MaskedTensors is in prototype stage "
|
||||
|
|
@ -162,12 +165,14 @@ class MaskedTensor(torch.Tensor):
|
|||
"module for further information about the project."
|
||||
),
|
||||
UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
if data.requires_grad:
|
||||
warnings.warn(
|
||||
"It is not recommended to create a MaskedTensor with a tensor that requires_grad. "
|
||||
"To avoid this, you can use data.clone().detach()",
|
||||
UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return torch.Tensor._make_wrapper_subclass(cls, data.size(), **kwargs) # type: ignore[attr-defined]
|
||||
|
||||
|
|
|
|||
|
|
@ -1598,7 +1598,9 @@ class DistributedDataParallel(Module, Joinable):
|
|||
treespec,
|
||||
output_is_rref,
|
||||
) = _tree_flatten_with_rref(output)
|
||||
output_placeholders = [None for _ in range(len(output_tensor_list))]
|
||||
output_placeholders: List[Optional[torch.Tensor]] = [
|
||||
None for _ in range(len(output_tensor_list))
|
||||
]
|
||||
# Do not touch tensors that have no grad_fn, which can cause issues
|
||||
# such as https://github.com/pytorch/pytorch/issues/60733
|
||||
for i, output in enumerate(output_tensor_list):
|
||||
|
|
|
|||
|
|
@ -1,5 +1,17 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import builtins
|
||||
|
||||
# In some cases, these basic types are shadowed by corresponding
|
||||
# top-level values. The underscore variants let us refer to these
|
||||
# types. See https://github.com/python/mypy/issues/4146 for why these
|
||||
# workarounds is necessary
|
||||
from builtins import ( # noqa: F401
|
||||
bool as _bool,
|
||||
bytes as _bytes,
|
||||
complex as _complex,
|
||||
float as _float,
|
||||
int as _int,
|
||||
str as _str,
|
||||
)
|
||||
from typing import Any, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
|
|
@ -20,71 +32,75 @@ _TensorOrTensorsOrGradEdge = Union[
|
|||
Sequence["GradientEdge"],
|
||||
]
|
||||
|
||||
# In some cases, these basic types are shadowed by corresponding
|
||||
# top-level values. The underscore variants let us refer to these
|
||||
# types. See https://github.com/python/mypy/issues/4146 for why these
|
||||
# workarounds is necessary
|
||||
_int = builtins.int
|
||||
_float = builtins.float
|
||||
_bool = builtins.bool
|
||||
_complex = builtins.complex
|
||||
|
||||
_dtype = torch.dtype
|
||||
_device = torch.device
|
||||
_qscheme = torch.qscheme
|
||||
_size = Union[torch.Size, List[_int], Tuple[_int, ...]]
|
||||
_layout = torch.layout
|
||||
_dispatchkey = Union[str, torch._C.DispatchKey]
|
||||
_size = Union[torch.Size, List[_int], Tuple[_int, ...]]
|
||||
_dispatchkey = Union[_str, torch._C.DispatchKey]
|
||||
|
||||
# Meta-type for "numeric" things; matches our docs
|
||||
Number = Union[builtins.int, builtins.float, builtins.bool]
|
||||
Number = Union[_int, _float, _bool]
|
||||
|
||||
# Meta-type for "device-like" things. Not to be confused with 'device' (a
|
||||
# literal device object). This nomenclature is consistent with PythonArgParser.
|
||||
# None means use the default device (typically CPU)
|
||||
Device = Optional[Union[_device, str, _int]]
|
||||
Device = Optional[Union[_device, _str, _int]]
|
||||
del Optional
|
||||
|
||||
# Storage protocol implemented by ${Type}StorageBase classes
|
||||
|
||||
|
||||
class Storage:
|
||||
_cdata: int
|
||||
_cdata: _int
|
||||
device: torch.device
|
||||
dtype: torch.dtype
|
||||
_torch_load_uninitialized: bool
|
||||
_torch_load_uninitialized: _bool
|
||||
|
||||
def __deepcopy__(self, memo) -> "Storage": # type: ignore[empty-body]
|
||||
def __deepcopy__(self, memo: dict) -> "Storage": # type: ignore[empty-body]
|
||||
...
|
||||
|
||||
def _new_shared(self, int) -> "Storage": # type: ignore[empty-body]
|
||||
def _new_shared(self, size: _int) -> "Storage": # type: ignore[empty-body]
|
||||
...
|
||||
|
||||
def _write_file(
|
||||
self, f: Any, is_real_file: _bool, save_size: _bool, element_size: int
|
||||
self,
|
||||
f: Any,
|
||||
is_real_file: _bool,
|
||||
save_size: _bool,
|
||||
element_size: _int,
|
||||
) -> None:
|
||||
...
|
||||
|
||||
def element_size(self) -> int: # type: ignore[empty-body]
|
||||
def element_size(self) -> _int: # type: ignore[empty-body]
|
||||
...
|
||||
|
||||
def is_shared(self) -> bool: # type: ignore[empty-body]
|
||||
def is_shared(self) -> _bool: # type: ignore[empty-body]
|
||||
...
|
||||
|
||||
def share_memory_(self) -> "Storage": # type: ignore[empty-body]
|
||||
...
|
||||
|
||||
def nbytes(self) -> int: # type: ignore[empty-body]
|
||||
def nbytes(self) -> _int: # type: ignore[empty-body]
|
||||
...
|
||||
|
||||
def cpu(self) -> "Storage": # type: ignore[empty-body]
|
||||
...
|
||||
|
||||
def data_ptr(self) -> int: # type: ignore[empty-body]
|
||||
def data_ptr(self) -> _int: # type: ignore[empty-body]
|
||||
...
|
||||
|
||||
def from_file(self, filename: str, shared: bool = False, nbytes: int = 0) -> "Storage": # type: ignore[empty-body]
|
||||
def from_file( # type: ignore[empty-body]
|
||||
self,
|
||||
filename: _str,
|
||||
shared: _bool = False,
|
||||
nbytes: _int = 0,
|
||||
) -> "Storage":
|
||||
...
|
||||
|
||||
def _new_with_file(self, f: Any, element_size: int) -> "Storage": # type: ignore[empty-body]
|
||||
def _new_with_file( # type: ignore[empty-body]
|
||||
self,
|
||||
f: Any,
|
||||
element_size: _int,
|
||||
) -> "Storage":
|
||||
...
|
||||
|
|
|
|||
|
|
@@ -7,12 +7,11 @@ from torch._inductor.fx_passes import joint_graph

 if __name__ == "__main__":
     # Start by deleting all the existing patterns.
-    for file in os.listdir(pattern_matcher.SERIALIZED_PATTERN_PATH):
-        if file in ("__init__.py", "__pycache__"):
+    for path in pattern_matcher.SERIALIZED_PATTERN_PATH.iterdir():
+        if path.name in {"__init__.py", "__pycache__"}:
             continue
-        file = pattern_matcher.SERIALIZED_PATTERN_PATH / file
-        if file.is_file():
-            file.unlink()
+        if path.is_file():
+            path.unlink()

     # Now have joint_graph load all known patterns and tell the pattern matcher
     # to serialize the patterns as it goes.