From 6d725e7d667dca8f46ebfde85f9eddac08c57715 Mon Sep 17 00:00:00 2001 From: Aaron Gokaslan Date: Mon, 18 Sep 2023 02:07:18 +0000 Subject: [PATCH] [BE]: enable ruff rules PLR1722 and PLW3301 (#109461) Enables two ruff rules derived from pylint: * PLR1722 replaces any exit() calls with sys.exit(). exit() is only designed to be used in repl contexts as may not always be imported by default. This always use the version in the sys module which is better * PLW3301 replaces nested min / max calls with simplified versions (ie. `min(a, min(b, c))` => `min(a, b. c)`). The new version is more idiomatic and more efficient. Pull Request resolved: https://github.com/pytorch/pytorch/pull/109461 Approved by: https://github.com/ezyang --- .ci/pytorch/win-test-helpers/run_python_nn_smoketests.py | 3 ++- .github/scripts/check_labels.py | 3 ++- benchmarks/dynamo/benchmarks.py | 3 ++- caffe2/python/models/download.py | 6 +++--- pyproject.toml | 2 ++ scripts/model_zoo/update-caffe2-models.py | 2 +- test/distributions/test_transforms.py | 2 +- tools/linter/adapters/actionlint_linter.py | 3 ++- tools/linter/adapters/bazel_linter.py | 3 ++- tools/linter/adapters/clangtidy_linter.py | 2 +- tools/linter/adapters/grep_linter.py | 2 +- tools/linter/adapters/lintrunner_version_linter.py | 3 ++- tools/linter/adapters/s3_init.py | 2 +- tools/linter/adapters/shellcheck_linter.py | 3 ++- tools/linter/clang_tidy/generate_build_files.py | 2 +- tools/test/test_heuristics.py | 2 +- tools/test/test_test_selections.py | 4 ++-- torch/backends/xeon/run_cpu.py | 4 ++-- torch/testing/_internal/common_methods_invocations.py | 2 +- torch/testing/_internal/common_utils.py | 4 ++-- torch/testing/_internal/opinfo/definitions/linalg.py | 2 +- 21 files changed, 34 insertions(+), 25 deletions(-) diff --git a/.ci/pytorch/win-test-helpers/run_python_nn_smoketests.py b/.ci/pytorch/win-test-helpers/run_python_nn_smoketests.py index 07b78b61511..5f45ad40056 100755 --- a/.ci/pytorch/win-test-helpers/run_python_nn_smoketests.py +++ b/.ci/pytorch/win-test-helpers/run_python_nn_smoketests.py @@ -2,6 +2,7 @@ import os import subprocess +import sys COMMON_TESTS = [ ( @@ -53,4 +54,4 @@ if __name__ == "__main__": print("Reruning with traceback enabled") print("Command:", command_string) subprocess.run(command_args, check=False) - exit(e.returncode) + sys.exit(e.returncode) diff --git a/.github/scripts/check_labels.py b/.github/scripts/check_labels.py index dd6d441224f..44122833c91 100755 --- a/.github/scripts/check_labels.py +++ b/.github/scripts/check_labels.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 """Check whether a PR has required labels.""" +import sys from typing import Any from github_utils import gh_delete_comment, gh_post_pr_comment @@ -46,7 +47,7 @@ def main() -> None: except Exception as e: pass - exit(0) + sys.exit(0) if __name__ == "__main__": diff --git a/benchmarks/dynamo/benchmarks.py b/benchmarks/dynamo/benchmarks.py index cb4cc84867c..c209781ffd6 100755 --- a/benchmarks/dynamo/benchmarks.py +++ b/benchmarks/dynamo/benchmarks.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 import argparse import os +import sys from typing import Set @@ -87,7 +88,7 @@ if __name__ == "__main__": torchbench.torchbench_main() else: print(f"Illegal model name? {name}") - exit(-1) + sys.exit(-1) else: import torchbench diff --git a/caffe2/python/models/download.py b/caffe2/python/models/download.py index 895f87a4e45..f7e0234cde8 100644 --- a/caffe2/python/models/download.py +++ b/caffe2/python/models/download.py @@ -25,7 +25,7 @@ DOWNLOAD_COLUMNS = 70 # Don't let urllib hang up on big downloads def signalHandler(signal, frame): print("Killing download...") - exit(0) + sys.exit(0) signal.signal(signal.SIGINT, signalHandler) @@ -107,7 +107,7 @@ def downloadModel(model, args): response = input(query) if response.upper() == 'N' or not response: print("Cancelling download...") - exit(0) + sys.exit(0) print("Overwriting existing folder! ({filename})".format(filename=model_folder)) deleteDirectory(model_folder) @@ -122,7 +122,7 @@ def downloadModel(model, args): print("Abort: {reason}".format(reason=str(e))) print("Cleaning up...") deleteDirectory(model_folder) - exit(0) + sys.exit(0) if args.install: os.symlink("{folder}/__sym_init__.py".format(folder=dir_path), diff --git a/pyproject.toml b/pyproject.toml index eb764cb895f..38f55c38ece 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,6 +74,8 @@ select = [ "PIE807", "PIE810", "PLE", + "PLR1722", # use sys exit + "PLW3301", # nested min max "RUF017", "TRY302", ] diff --git a/scripts/model_zoo/update-caffe2-models.py b/scripts/model_zoo/update-caffe2-models.py index aad5ec58e5d..309c76881f5 100755 --- a/scripts/model_zoo/update-caffe2-models.py +++ b/scripts/model_zoo/update-caffe2-models.py @@ -37,7 +37,7 @@ class SomeClass: print(f"Abort: {e}") print("Cleaning up...") deleteDirectory(model_dir) - exit(1) + sys.exit(1) def _caffe2_model_dir(self, model): caffe2_home = os.path.expanduser("~/.caffe2") diff --git a/test/distributions/test_transforms.py b/test/distributions/test_transforms.py index 6fd4cf818d6..54fd13d9ece 100644 --- a/test/distributions/test_transforms.py +++ b/test/distributions/test_transforms.py @@ -418,7 +418,7 @@ def test_compose_affine(event_dims): if transform.domain.event_dim > 1: base_dist = base_dist.expand((1,) * (transform.domain.event_dim - 1)) dist = TransformedDistribution(base_dist, transforms) - assert dist.support.event_dim == max(1, max(event_dims)) + assert dist.support.event_dim == max(1, *event_dims) @pytest.mark.parametrize("batch_shape", [(), (6,), (5, 4)], ids=str) diff --git a/tools/linter/adapters/actionlint_linter.py b/tools/linter/adapters/actionlint_linter.py index 169451ca1ce..a8f599c2e1e 100644 --- a/tools/linter/adapters/actionlint_linter.py +++ b/tools/linter/adapters/actionlint_linter.py @@ -5,6 +5,7 @@ import logging import os import re import subprocess +import sys import time from enum import Enum from typing import List, NamedTuple, Optional, Pattern @@ -131,7 +132,7 @@ if __name__ == "__main__": ), ) print(json.dumps(err_msg._asdict()), flush=True) - exit(0) + sys.exit(0) with concurrent.futures.ThreadPoolExecutor( max_workers=os.cpu_count(), diff --git a/tools/linter/adapters/bazel_linter.py b/tools/linter/adapters/bazel_linter.py index 9cc15caa0e3..f3c4a95c6e4 100644 --- a/tools/linter/adapters/bazel_linter.py +++ b/tools/linter/adapters/bazel_linter.py @@ -10,6 +10,7 @@ import json import re import shlex import subprocess +import sys import xml.etree.ElementTree as ET from enum import Enum from typing import List, NamedTuple, Optional, Set @@ -182,7 +183,7 @@ def main() -> None: description=(f"Failed due to {e.__class__.__name__}:\n{e}"), ) print(json.dumps(err_msg._asdict()), flush=True) - exit(0) + sys.exit(0) for filename in args.filenames: for lint_message in check_bazel(filename, disallowed_checksums): diff --git a/tools/linter/adapters/clangtidy_linter.py b/tools/linter/adapters/clangtidy_linter.py index 5a296166a90..43947976743 100644 --- a/tools/linter/adapters/clangtidy_linter.py +++ b/tools/linter/adapters/clangtidy_linter.py @@ -249,7 +249,7 @@ def main() -> None: ), ) print(json.dumps(err_msg._asdict()), flush=True) - exit(0) + sys.exit(0) abs_build_dir = Path(args.build_dir).resolve() diff --git a/tools/linter/adapters/grep_linter.py b/tools/linter/adapters/grep_linter.py index 64dac4cdc07..168800eb447 100644 --- a/tools/linter/adapters/grep_linter.py +++ b/tools/linter/adapters/grep_linter.py @@ -252,7 +252,7 @@ def main() -> None: ), ) print(json.dumps(err_msg._asdict()), flush=True) - exit(0) + sys.exit(0) lines = proc.stdout.decode().splitlines() for line in lines: diff --git a/tools/linter/adapters/lintrunner_version_linter.py b/tools/linter/adapters/lintrunner_version_linter.py index dc9828e8d75..48eab1a39a8 100644 --- a/tools/linter/adapters/lintrunner_version_linter.py +++ b/tools/linter/adapters/lintrunner_version_linter.py @@ -1,5 +1,6 @@ import json import subprocess +import sys from enum import Enum from typing import NamedTuple, Optional, Tuple @@ -53,7 +54,7 @@ if __name__ == "__main__": replacement=None, description="Lintrunner is not installed, did you forget to run `make setup_lint && make lint`?", ) - exit(0) + sys.exit(0) curr_version = int(version_match[1]), int(version_match[2]), int(version_match[3]) min_version = (0, 10, 7) diff --git a/tools/linter/adapters/s3_init.py b/tools/linter/adapters/s3_init.py index 7befbf27f90..260a3d4394b 100644 --- a/tools/linter/adapters/s3_init.py +++ b/tools/linter/adapters/s3_init.py @@ -205,7 +205,7 @@ if __name__ == "__main__": # If the host platform is not in platform_to_hash, it is unsupported. if host_platform not in config: logging.error("Unsupported platform: %s/%s", HOST_PLATFORM, HOST_PLATFORM_ARCH) - exit(1) + sys.exit(1) url = config[host_platform]["download_url"] hash = config[host_platform]["hash"] diff --git a/tools/linter/adapters/shellcheck_linter.py b/tools/linter/adapters/shellcheck_linter.py index bcf0b2a517b..5d8d7a8052e 100644 --- a/tools/linter/adapters/shellcheck_linter.py +++ b/tools/linter/adapters/shellcheck_linter.py @@ -3,6 +3,7 @@ import json import logging import shutil import subprocess +import sys import time from enum import Enum from typing import List, NamedTuple, Optional @@ -108,7 +109,7 @@ if __name__ == "__main__": description="shellcheck is not installed, did you forget to run `lintrunner init`?", ) print(json.dumps(err_msg._asdict()), flush=True) - exit(0) + sys.exit(0) args = parser.parse_args() diff --git a/tools/linter/clang_tidy/generate_build_files.py b/tools/linter/clang_tidy/generate_build_files.py index 7e56ecb6d3b..692bbbe5456 100644 --- a/tools/linter/clang_tidy/generate_build_files.py +++ b/tools/linter/clang_tidy/generate_build_files.py @@ -18,7 +18,7 @@ def run_cmd(cmd: List[str]) -> None: print(stderr) if result.returncode != 0: print(f"Failed to run {cmd}") - exit(1) + sys.exit(1) def update_submodules() -> None: diff --git a/tools/test/test_heuristics.py b/tools/test/test_heuristics.py index 0e135848fd2..380297ac09d 100644 --- a/tools/test/test_heuristics.py +++ b/tools/test/test_heuristics.py @@ -21,7 +21,7 @@ try: except ModuleNotFoundError: print("Can't import required modules, exiting") - exit(1) + sys.exit(1) def mocked_file(contents: Dict[Any, Any]) -> io.IOBase: diff --git a/tools/test/test_test_selections.py b/tools/test/test_test_selections.py index f8aa78b9946..1a06cb20675 100644 --- a/tools/test/test_test_selections.py +++ b/tools/test/test_test_selections.py @@ -12,7 +12,7 @@ try: from tools.testing.test_selections import calculate_shards, ShardedTest, THRESHOLD except ModuleNotFoundError: print("Can't import required modules, exiting") - exit(1) + sys.exit(1) class TestCalculateShards(unittest.TestCase): @@ -308,7 +308,7 @@ class TestCalculateShards(unittest.TestCase): if k != "super_long_test" and k != "long_test1" ] sum_of_rest = sum(rest_of_tests) - random_times["super_long_test"] = max(sum_of_rest / 2, max(rest_of_tests)) + random_times["super_long_test"] = max(sum_of_rest / 2, *rest_of_tests) random_times["long_test1"] = sum_of_rest - random_times["super_long_test"] # An optimal sharding would look like the below, but we don't need to compute this for the test: # optimal_shards = [ diff --git a/torch/backends/xeon/run_cpu.py b/torch/backends/xeon/run_cpu.py index accfea49873..26d3df22be7 100644 --- a/torch/backends/xeon/run_cpu.py +++ b/torch/backends/xeon/run_cpu.py @@ -500,7 +500,7 @@ socket", ncore_per_node, args.ncores_per_instance, ) - exit(-1) + sys.exit(-1) elif num_leftover_cores == 0: # aren't any cross-node cores logger.info( @@ -573,7 +573,7 @@ won't take effect even if it is set explicitly." "Core binding with numactl is not available, and --disable_taskset is set. \ Please unset --disable_taskset to use taskset instead of numactl." ) - exit(-1) + sys.exit(-1) if not args.disable_taskset: enable_taskset = True diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 42cd3612759..250cb78a941 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -2083,7 +2083,7 @@ def sample_inputs_singular_matrix_factors(op_info, device, dtype, requires_grad= size = [1, 5, 10] for batch, m, n in product(batches, size, size): - for k in range(min(3, min(m, n))): + for k in range(min(3, m, n)): a = make_arg((*batch, m, k)) b = make_arg((*batch, n, k)) yield SampleInput(a, b, **kwargs) diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 61f72026ae3..0e5393bb5cf 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -934,11 +934,11 @@ def run_tests(argv=UNITTEST_ARGS): if not RERUN_DISABLED_TESTS: # exitcode of 5 means no tests were found, which happens since some test configs don't # run tests from certain files - exit(0 if exit_code == 5 else exit_code) + sys.exit(0 if exit_code == 5 else exit_code) else: # Only record the test report and always return a success code when running under rerun # disabled tests mode - exit(0) + sys.exit(0) elif TEST_SAVE_XML is not None: # import here so that non-CI doesn't need xmlrunner installed import xmlrunner # type: ignore[import] diff --git a/torch/testing/_internal/opinfo/definitions/linalg.py b/torch/testing/_internal/opinfo/definitions/linalg.py index b1677bf5821..033bf4100be 100644 --- a/torch/testing/_internal/opinfo/definitions/linalg.py +++ b/torch/testing/_internal/opinfo/definitions/linalg.py @@ -535,7 +535,7 @@ def sample_inputs_linalg_pinv_singular( size = [0, 3, 50] for batch, m, n in product(batches, size, size): - for k in range(min(3, min(m, n))): + for k in range(min(3, m, n)): # Note that by making the columns of `a` and `b` orthonormal we make sure that # the product matrix `a @ b.t()` has condition number 1 when restricted to its image a = (