From 7f65a208848205b38445423b7e2e93a2b4994e5e Mon Sep 17 00:00:00 2001 From: Aaron Gokaslan Date: Tue, 4 Feb 2025 19:18:21 +0000 Subject: [PATCH] [BE]: Enable ruff SLOT checks (#146276) This enables a check that which a class which only inherits from immutable classes like str, tuple, and NamedTuple, also defined `__slots__` so they don't allocate memory unnecessarily. This also ensure contributors think about how they define their classes with subclass NamedTuples and str, of which we have many in our codebase Pull Request resolved: https://github.com/pytorch/pytorch/pull/146276 Approved by: https://github.com/aorenste --- .github/scripts/pytest_caching_utils.py | 2 ++ functorch/dim/reference.py | 6 ++++-- pyproject.toml | 1 + test/dynamo/test_functions.py | 4 ++-- test/torch_np/numpy_tests/core/test_indexing.py | 2 +- torch/_dynamo/polyfills/pytree.py | 2 ++ torch/_export/serde/union.py | 1 + torch/_utils.py | 2 ++ torch/ao/quantization/fx/_equalize.py | 2 ++ torch/ao/quantization/qconfig.py | 4 ++++ torch/distributed/distributed_c10d.py | 3 ++- torch/fx/passes/infra/pass_base.py | 2 ++ torch/nn/modules/module.py | 2 ++ torch/testing/_internal/common_dtype.py | 2 ++ torch/testing/_internal/jit_metaprogramming_utils.py | 2 +- torch/torch_version.py | 2 ++ 16 files changed, 32 insertions(+), 7 deletions(-) diff --git a/.github/scripts/pytest_caching_utils.py b/.github/scripts/pytest_caching_utils.py index 0141bfd8da6..0cfb4e823f6 100644 --- a/.github/scripts/pytest_caching_utils.py +++ b/.github/scripts/pytest_caching_utils.py @@ -30,6 +30,8 @@ UNZIPPED_CACHES = "unzipped-caches" # Since the pr identifier can be based on include user defined text (like a branch name) # we hash it to sanitize the input and avoid corner cases class PRIdentifier(str): + __slots__ = () + def __new__(cls, value: str) -> "PRIdentifier": md5 = hashlib.md5(value.encode("utf-8")).hexdigest() return super().__new__(cls, md5) diff --git a/functorch/dim/reference.py b/functorch/dim/reference.py index 6453a441b94..01992cc5c12 100644 --- a/functorch/dim/reference.py +++ b/functorch/dim/reference.py @@ -139,6 +139,8 @@ def seq(a, b): class isin: + __slots__ = () + def __contains__(self, item): for x in self: if seq(item, x): @@ -153,11 +155,11 @@ class isin: class llist(isin, list): - pass + __slots__ = () class ltuple(isin, tuple): - pass + __slots__ = () empty_dict = {} diff --git a/pyproject.toml b/pyproject.toml index 1f9844eabbe..29dd32040fa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -148,6 +148,7 @@ select = [ "RUF019", # unnecessary-key-check "RUF024", # from keys mutable "RUF026", # default factory kwarg + "SLOT", "TCH", "TRY002", # ban vanilla raise (todo fix NOQAs) "TRY203", diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index 02f3b8452f4..f34fe06751d 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -4448,7 +4448,7 @@ class DefaultsTests(torch._dynamo.test_case.TestCase): self.assertEqual(fn(inputs, x), opt_fn(inputs, x)) def test_udf_tuple(self): - class MyTuple(tuple): + class MyTuple(tuple): # noqa: SLOT001 def len_mulitply_2(self): return len(self) * 2 @@ -4475,7 +4475,7 @@ class DefaultsTests(torch._dynamo.test_case.TestCase): self.assertTrue(res_tup.checked) def test_udf_tuple_reconstruction(self): - class MyTuple(tuple): + class MyTuple(tuple): # noqa: SLOT001 pass def fn(x, klass): diff --git a/test/torch_np/numpy_tests/core/test_indexing.py b/test/torch_np/numpy_tests/core/test_indexing.py index 55d7aa4675d..ed402bd8595 100644 --- a/test/torch_np/numpy_tests/core/test_indexing.py +++ b/test/torch_np/numpy_tests/core/test_indexing.py @@ -413,7 +413,7 @@ class TestIndexing(TestCase): # A tuple subclass should also be an nd-index class TupleSubclass(tuple): - pass + __slots__ = () index = ([1], [1]) index = TupleSubclass(index) diff --git a/torch/_dynamo/polyfills/pytree.py b/torch/_dynamo/polyfills/pytree.py index 20507782dc2..c62f19e3440 100644 --- a/torch/_dynamo/polyfills/pytree.py +++ b/torch/_dynamo/polyfills/pytree.py @@ -105,6 +105,8 @@ if python_pytree._cxx_pytree_dynamo_traceable: __all__ += ["tree_leaves"] class _Asterisk(str): + __slots__ = () + def __new__(cls) -> Self: return super().__new__(cls, "*") diff --git a/torch/_export/serde/union.py b/torch/_export/serde/union.py index 006b809e1e5..ca8a87951ea 100644 --- a/torch/_export/serde/union.py +++ b/torch/_export/serde/union.py @@ -5,6 +5,7 @@ from dataclasses import fields class _UnionTag(str): + __slots__ = ("_cls",) _cls: Hashable @staticmethod diff --git a/torch/_utils.py b/torch/_utils.py index 7c645435f87..f227042803f 100644 --- a/torch/_utils.py +++ b/torch/_utils.py @@ -696,6 +696,8 @@ def render_call(fn, args, kwargs): class KeyErrorMessage(str): r"""str subclass that returns itself in repr""" + __slots__ = () + def __repr__(self): return self diff --git a/torch/ao/quantization/fx/_equalize.py b/torch/ao/quantization/fx/_equalize.py index 12310338514..77bc4e31d19 100644 --- a/torch/ao/quantization/fx/_equalize.py +++ b/torch/ao/quantization/fx/_equalize.py @@ -266,6 +266,8 @@ class EqualizationQConfig( weight=_WeightEqualizationObserver.with_args(dtype=torch.qint8)) """ + __slots__ = () + def __new__(cls, input_activation=torch.nn.Identity, weight=torch.nn.Identity): if isinstance(input_activation, nn.Module) or isinstance(weight, nn.Module): raise ValueError( diff --git a/torch/ao/quantization/qconfig.py b/torch/ao/quantization/qconfig.py index e7e9cfe9db4..246d74b601c 100644 --- a/torch/ao/quantization/qconfig.py +++ b/torch/ao/quantization/qconfig.py @@ -102,6 +102,8 @@ class QConfig(namedtuple("QConfig", ["activation", "weight"])): """ + __slots__ = () + def __new__(cls, activation, weight): # catch common mistakes if isinstance(activation, nn.Module) or isinstance(weight, nn.Module): @@ -133,6 +135,8 @@ class QConfigDynamic(namedtuple("QConfigDynamic", ["activation", "weight"])): my_qconfig = QConfigDynamic(weight=default_observer.with_args(dtype=torch.qint8)) """ + __slots__ = () + def __new__(cls, activation=torch.nn.Identity, weight=torch.nn.Identity): # catch common mistakes if isinstance(weight, nn.Module): diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index ff89dd79f2d..2f4b7967b52 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -231,7 +231,8 @@ def supports_complex(reduceOp: ReduceOp) -> bool: return reduceOp not in denyList -class Backend(str): +# TODO refactor into enum/strenum +class Backend(str): # noqa: SLOT000 """ An enum-like class for backends. diff --git a/torch/fx/passes/infra/pass_base.py b/torch/fx/passes/infra/pass_base.py index acf78d2581b..957b8145f99 100644 --- a/torch/fx/passes/infra/pass_base.py +++ b/torch/fx/passes/infra/pass_base.py @@ -18,6 +18,8 @@ class PassResult(namedtuple("PassResult", ["graph_module", "modified"])): modified: A flag for if the pass has modified the graph module """ + __slots__ = () + def __new__(cls, graph_module, modified): return super().__new__(cls, graph_module, modified) diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index 76662e251a2..68310ba33a3 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -40,6 +40,8 @@ T = TypeVar("T", bound="Module") class _IncompatibleKeys( namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"]), ): + __slots__ = () + def __repr__(self): if not self.missing_keys and not self.unexpected_keys: return "" diff --git a/torch/testing/_internal/common_dtype.py b/torch/testing/_internal/common_dtype.py index 5a340ea670d..774ce179f33 100644 --- a/torch/testing/_internal/common_dtype.py +++ b/torch/testing/_internal/common_dtype.py @@ -17,6 +17,8 @@ def _validate_dtypes(*dtypes): # class for tuples corresponding to a PyTorch dispatch macro class _dispatch_dtypes(tuple): + __slots__ = () + def __add__(self, other): assert isinstance(other, tuple) return _dispatch_dtypes(tuple.__add__(self, other)) diff --git a/torch/testing/_internal/jit_metaprogramming_utils.py b/torch/testing/_internal/jit_metaprogramming_utils.py index 2f0121af743..b3dbb95f4ba 100644 --- a/torch/testing/_internal/jit_metaprogramming_utils.py +++ b/torch/testing/_internal/jit_metaprogramming_utils.py @@ -33,7 +33,7 @@ def unpack_variables(args): return args class dont_convert(tuple): - pass + __slots__ = () non_differentiable = collections.namedtuple('non_differentiable', ['tensor']) diff --git a/torch/torch_version.py b/torch/torch_version.py index ca4042ed7f9..0496a1b564f 100644 --- a/torch/torch_version.py +++ b/torch/torch_version.py @@ -26,6 +26,8 @@ class TorchVersion(str): TorchVersion('1.10.0a') > '1.2.1' """ + __slots__ = () + # fully qualified type names here to appease mypy def _convert_to_version(self, inp: Any) -> Any: if isinstance(inp, Version):