diff --git a/.github/scripts/pytest_caching_utils.py b/.github/scripts/pytest_caching_utils.py index 0141bfd8da6..0cfb4e823f6 100644 --- a/.github/scripts/pytest_caching_utils.py +++ b/.github/scripts/pytest_caching_utils.py @@ -30,6 +30,8 @@ UNZIPPED_CACHES = "unzipped-caches" # Since the pr identifier can be based on include user defined text (like a branch name) # we hash it to sanitize the input and avoid corner cases class PRIdentifier(str): + __slots__ = () + def __new__(cls, value: str) -> "PRIdentifier": md5 = hashlib.md5(value.encode("utf-8")).hexdigest() return super().__new__(cls, md5) diff --git a/functorch/dim/reference.py b/functorch/dim/reference.py index 6453a441b94..01992cc5c12 100644 --- a/functorch/dim/reference.py +++ b/functorch/dim/reference.py @@ -139,6 +139,8 @@ def seq(a, b): class isin: + __slots__ = () + def __contains__(self, item): for x in self: if seq(item, x): @@ -153,11 +155,11 @@ class isin: class llist(isin, list): - pass + __slots__ = () class ltuple(isin, tuple): - pass + __slots__ = () empty_dict = {} diff --git a/pyproject.toml b/pyproject.toml index 1f9844eabbe..29dd32040fa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -148,6 +148,7 @@ select = [ "RUF019", # unnecessary-key-check "RUF024", # from keys mutable "RUF026", # default factory kwarg + "SLOT", "TCH", "TRY002", # ban vanilla raise (todo fix NOQAs) "TRY203", diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index 02f3b8452f4..f34fe06751d 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -4448,7 +4448,7 @@ class DefaultsTests(torch._dynamo.test_case.TestCase): self.assertEqual(fn(inputs, x), opt_fn(inputs, x)) def test_udf_tuple(self): - class MyTuple(tuple): + class MyTuple(tuple): # noqa: SLOT001 def len_mulitply_2(self): return len(self) * 2 @@ -4475,7 +4475,7 @@ class DefaultsTests(torch._dynamo.test_case.TestCase): self.assertTrue(res_tup.checked) def test_udf_tuple_reconstruction(self): - class MyTuple(tuple): + class MyTuple(tuple): # noqa: SLOT001 pass def fn(x, klass): diff --git a/test/torch_np/numpy_tests/core/test_indexing.py b/test/torch_np/numpy_tests/core/test_indexing.py index 55d7aa4675d..ed402bd8595 100644 --- a/test/torch_np/numpy_tests/core/test_indexing.py +++ b/test/torch_np/numpy_tests/core/test_indexing.py @@ -413,7 +413,7 @@ class TestIndexing(TestCase): # A tuple subclass should also be an nd-index class TupleSubclass(tuple): - pass + __slots__ = () index = ([1], [1]) index = TupleSubclass(index) diff --git a/torch/_dynamo/polyfills/pytree.py b/torch/_dynamo/polyfills/pytree.py index 20507782dc2..c62f19e3440 100644 --- a/torch/_dynamo/polyfills/pytree.py +++ b/torch/_dynamo/polyfills/pytree.py @@ -105,6 +105,8 @@ if python_pytree._cxx_pytree_dynamo_traceable: __all__ += ["tree_leaves"] class _Asterisk(str): + __slots__ = () + def __new__(cls) -> Self: return super().__new__(cls, "*") diff --git a/torch/_export/serde/union.py b/torch/_export/serde/union.py index 006b809e1e5..ca8a87951ea 100644 --- a/torch/_export/serde/union.py +++ b/torch/_export/serde/union.py @@ -5,6 +5,7 @@ from dataclasses import fields class _UnionTag(str): + __slots__ = ("_cls",) _cls: Hashable @staticmethod diff --git a/torch/_utils.py b/torch/_utils.py index 7c645435f87..f227042803f 100644 --- a/torch/_utils.py +++ b/torch/_utils.py @@ -696,6 +696,8 @@ def render_call(fn, args, kwargs): class KeyErrorMessage(str): r"""str subclass that returns itself in repr""" + __slots__ = () + def __repr__(self): return self diff --git a/torch/ao/quantization/fx/_equalize.py b/torch/ao/quantization/fx/_equalize.py index 12310338514..77bc4e31d19 100644 --- a/torch/ao/quantization/fx/_equalize.py +++ b/torch/ao/quantization/fx/_equalize.py @@ -266,6 +266,8 @@ class EqualizationQConfig( weight=_WeightEqualizationObserver.with_args(dtype=torch.qint8)) """ + __slots__ = () + def __new__(cls, input_activation=torch.nn.Identity, weight=torch.nn.Identity): if isinstance(input_activation, nn.Module) or isinstance(weight, nn.Module): raise ValueError( diff --git a/torch/ao/quantization/qconfig.py b/torch/ao/quantization/qconfig.py index e7e9cfe9db4..246d74b601c 100644 --- a/torch/ao/quantization/qconfig.py +++ b/torch/ao/quantization/qconfig.py @@ -102,6 +102,8 @@ class QConfig(namedtuple("QConfig", ["activation", "weight"])): """ + __slots__ = () + def __new__(cls, activation, weight): # catch common mistakes if isinstance(activation, nn.Module) or isinstance(weight, nn.Module): @@ -133,6 +135,8 @@ class QConfigDynamic(namedtuple("QConfigDynamic", ["activation", "weight"])): my_qconfig = QConfigDynamic(weight=default_observer.with_args(dtype=torch.qint8)) """ + __slots__ = () + def __new__(cls, activation=torch.nn.Identity, weight=torch.nn.Identity): # catch common mistakes if isinstance(weight, nn.Module): diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index ff89dd79f2d..2f4b7967b52 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -231,7 +231,8 @@ def supports_complex(reduceOp: ReduceOp) -> bool: return reduceOp not in denyList -class Backend(str): +# TODO refactor into enum/strenum +class Backend(str): # noqa: SLOT000 """ An enum-like class for backends. diff --git a/torch/fx/passes/infra/pass_base.py b/torch/fx/passes/infra/pass_base.py index acf78d2581b..957b8145f99 100644 --- a/torch/fx/passes/infra/pass_base.py +++ b/torch/fx/passes/infra/pass_base.py @@ -18,6 +18,8 @@ class PassResult(namedtuple("PassResult", ["graph_module", "modified"])): modified: A flag for if the pass has modified the graph module """ + __slots__ = () + def __new__(cls, graph_module, modified): return super().__new__(cls, graph_module, modified) diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index 76662e251a2..68310ba33a3 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -40,6 +40,8 @@ T = TypeVar("T", bound="Module") class _IncompatibleKeys( namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"]), ): + __slots__ = () + def __repr__(self): if not self.missing_keys and not self.unexpected_keys: return "" diff --git a/torch/testing/_internal/common_dtype.py b/torch/testing/_internal/common_dtype.py index 5a340ea670d..774ce179f33 100644 --- a/torch/testing/_internal/common_dtype.py +++ b/torch/testing/_internal/common_dtype.py @@ -17,6 +17,8 @@ def _validate_dtypes(*dtypes): # class for tuples corresponding to a PyTorch dispatch macro class _dispatch_dtypes(tuple): + __slots__ = () + def __add__(self, other): assert isinstance(other, tuple) return _dispatch_dtypes(tuple.__add__(self, other)) diff --git a/torch/testing/_internal/jit_metaprogramming_utils.py b/torch/testing/_internal/jit_metaprogramming_utils.py index 2f0121af743..b3dbb95f4ba 100644 --- a/torch/testing/_internal/jit_metaprogramming_utils.py +++ b/torch/testing/_internal/jit_metaprogramming_utils.py @@ -33,7 +33,7 @@ def unpack_variables(args): return args class dont_convert(tuple): - pass + __slots__ = () non_differentiable = collections.namedtuple('non_differentiable', ['tensor']) diff --git a/torch/torch_version.py b/torch/torch_version.py index ca4042ed7f9..0496a1b564f 100644 --- a/torch/torch_version.py +++ b/torch/torch_version.py @@ -26,6 +26,8 @@ class TorchVersion(str): TorchVersion('1.10.0a') > '1.2.1' """ + __slots__ = () + # fully qualified type names here to appease mypy def _convert_to_version(self, inp: Any) -> Any: if isinstance(inp, Version):