From 7d2a18da0b3427fcbe44b461a0aa508194535885 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 24 Apr 2023 08:37:13 -0700 Subject: [PATCH] Enable ruff in lintrunner (#99785) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### This change - Implements the ruff linter in pytorch lintrunner. It is adapted from https://github.com/justinchuby/lintrunner-adapters/blob/main/lintrunner_adapters/adapters/ruff_linter.py. It does **both linting and fixing**. ๐Ÿ”ง - Migrated all flake8 configs to the ruff config and enabled it for the repo. โœ… - **`ruff` lints the whole repo in under 2s** ๐Ÿคฏ Fixes https://github.com/pytorch/pytorch/issues/94737 Replaces #99280 @huydhn @Skylion007 ### ๐Ÿค– Generated by Copilot at 6b982dd ### Summary ๐Ÿงน๐Ÿ› ๏ธ๐ŸŽจ Add `[tool.ruff]` section to `pyproject.toml` to configure `ruff` code formatter and linter. This change aims to improve code quality and consistency with a single tool. > _`ruff` cleans the code_ > _like a spring breeze in the fields_ > _`pyproject.toml`_ ### Walkthrough * Configure `ruff` code formatter and linter for the whole project ([link](https://github.com/pytorch/pytorch/pull/99785/files?diff=unified&w=0#diff-50c86b7ed8ac2cf95bd48334961bf0530cdc77b5a56f852c5c61b89d735fd711R22-R79)) Pull Request resolved: https://github.com/pytorch/pytorch/pull/99785 Approved by: https://github.com/malfet, https://github.com/Skylion007 --- .flake8 | 2 + .lintrunner.toml | 26 + pyproject.toml | 60 +++ tools/linter/adapters/ruff_linter.py | 462 ++++++++++++++++++ tools/test/test_selective_build.py | 2 +- torch/distributed/pipeline/sync/microbatch.py | 2 +- 6 files changed, 552 insertions(+), 2 deletions(-) create mode 100644 tools/linter/adapters/ruff_linter.py diff --git a/.flake8 b/.flake8 index 4d1aa57ef74..8ac73363e62 100644 --- a/.flake8 +++ b/.flake8 @@ -1,4 +1,6 @@ [flake8] +# NOTE: **Mirror any changes** to this file the [tool.ruff] config in pyproject.toml +# before we can fully move to use ruff enable-extensions = G select = B,C,E,F,G,P,SIM1,T4,W,B9 max-line-length = 120 diff --git a/.lintrunner.toml b/.lintrunner.toml index 10712537a7d..0425ebd5678 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -972,3 +972,29 @@ command = [ 'python3', 'tools/linter/adapters/lintrunner_version_linter.py' ] + +[[linter]] +code = 'RUFF' +include_patterns = ['**/*.py'] +exclude_patterns = [ + 'caffe2/**', + 'functorch/docs/**', + 'functorch/notebooks/**', + 'scripts/**', + 'third_party/**', +] +command = [ + 'python3', + 'tools/linter/adapters/ruff_linter.py', + '--config=pyproject.toml', + '--show-disable', + '--', + '@{{PATHSFILE}}' +] +init_command = [ + 'python3', + 'tools/linter/adapters/pip_init.py', + '--dry-run={{DRYRUN}}', + 'ruff==0.0.262', +] +is_formatter = true diff --git a/pyproject.toml b/pyproject.toml index 4c39bb60336..c74e413096d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,3 +19,63 @@ build-backend = "setuptools.build_meta:__legacy__" # Uncomment if pyproject.toml worked fine to ensure consistency with flake8 # line-length = 120 target-version = ["py38", "py39", "py310", "py311"] + + +[tool.ruff] +target-version = "py38" + +# NOTE: Synchoronize the ignores with .flake8 +ignore = [ + # these ignores are from flake8-bugbear; please fix! + "B007", "B008", "B017", + "B018", # Useless expression + "B019", "B020", + "B022", # Allow empty context manager + "B023", "B024", "B026", + "B028", # No explicit `stacklevel` keyword argument found + "B027", "B904", "B905", + "E402", + "C408", # C408 ignored because we like the dict keyword argument syntax + "C419", # generators may not be supported by jit + "E501", # E501 is not flexible enough, we're using B950 instead + "E721", + "E731", # Assign lambda expression + "E741", + "EXE001", + "F405", + "F821", + "F841", + # these ignores are from flake8-logging-format; please fix! + "G101", "G201", "G202", + "SIM102", "SIM103", "SIM112", # flake8-simplify code styles + "SIM105", # these ignores are from flake8-simplify. please fix or ignore with commented reason + "SIM108", + "SIM109", + "SIM110", + "SIM114", # Combine `if` branches using logical `or` operator + "SIM115", + "SIM116", # Disable Use a dictionary instead of consecutive `if` statements + "SIM117", + "SIM118", +] +line-length = 120 +select = [ + "B", + "C4", + "G", + "E", + "F", + "SIM1", + "W", +] + +[tool.ruff.per-file-ignores] +"__init__.py" = ["F401"] +"torchgen/api/types/__init__.py" = [ + "F401", + "F403", +] +"torchgen/executorch/api/types/__init__.py" = [ + "F401", + "F403", +] diff --git a/tools/linter/adapters/ruff_linter.py b/tools/linter/adapters/ruff_linter.py new file mode 100644 index 00000000000..451834aa7c3 --- /dev/null +++ b/tools/linter/adapters/ruff_linter.py @@ -0,0 +1,462 @@ +"""Adapter for https://github.com/charliermarsh/ruff.""" + +from __future__ import annotations + +import argparse +import concurrent.futures +import dataclasses +import enum +import json +import logging +import os +import subprocess +import sys +import time +from typing import Any, BinaryIO + +LINTER_CODE = "RUFF" +IS_WINDOWS: bool = os.name == "nt" + + +def eprint(*args: Any, **kwargs: Any) -> None: + """Print to stderr.""" + print(*args, file=sys.stderr, flush=True, **kwargs) + + +class LintSeverity(str, enum.Enum): + """Severity of a lint message.""" + + ERROR = "error" + WARNING = "warning" + ADVICE = "advice" + DISABLED = "disabled" + + +@dataclasses.dataclass(frozen=True) +class LintMessage: + """A lint message defined by https://docs.rs/lintrunner/latest/lintrunner/lint_message/struct.LintMessage.html.""" + + path: str | None + line: int | None + char: int | None + code: str + severity: LintSeverity + name: str + original: str | None + replacement: str | None + description: str | None + + def asdict(self) -> dict[str, Any]: + return dataclasses.asdict(self) + + def display(self) -> None: + """Print to stdout for lintrunner to consume.""" + print(json.dumps(self.asdict()), flush=True) + + +def as_posix(name: str) -> str: + return name.replace("\\", "/") if IS_WINDOWS else name + + +def _run_command( + args: list[str], + *, + timeout: int | None, + stdin: BinaryIO | None, + input: bytes | None, + check: bool, + cwd: os.PathLike[Any] | None, +) -> subprocess.CompletedProcess[bytes]: + logging.debug("$ %s", " ".join(args)) + start_time = time.monotonic() + try: + if input is not None: + return subprocess.run( + args, + capture_output=True, + shell=False, + input=input, + timeout=timeout, + check=check, + cwd=cwd, + ) + + return subprocess.run( + args, + stdin=stdin, + capture_output=True, + shell=False, + timeout=timeout, + check=check, + cwd=cwd, + ) + finally: + end_time = time.monotonic() + logging.debug("took %dms", (end_time - start_time) * 1000) + + +def run_command( + args: list[str], + *, + retries: int = 0, + timeout: int | None = None, + stdin: BinaryIO | None = None, + input: bytes | None = None, + check: bool = False, + cwd: os.PathLike[Any] | None = None, +) -> subprocess.CompletedProcess[bytes]: + remaining_retries = retries + while True: + try: + return _run_command( + args, timeout=timeout, stdin=stdin, input=input, check=check, cwd=cwd + ) + except subprocess.TimeoutExpired as err: + if remaining_retries == 0: + raise err + remaining_retries -= 1 + logging.warning( + "(%s/%s) Retrying because command failed with: %r", + retries - remaining_retries, + retries, + err, + ) + time.sleep(1) + + +def add_default_options(parser: argparse.ArgumentParser) -> None: + """Add default options to a parser. + + This should be called the last in the chain of add_argument calls. + """ + parser.add_argument( + "--retries", + type=int, + default=3, + help="number of times to retry if the linter times out.", + ) + parser.add_argument( + "--verbose", + action="store_true", + help="verbose logging", + ) + parser.add_argument( + "filenames", + nargs="+", + help="paths to lint", + ) + + +def explain_rule(code: str) -> str: + proc = run_command( + ["ruff", "rule", "--format=json", code], + check=True, + ) + rule = json.loads(str(proc.stdout, "utf-8").strip()) + return f"\n{rule['linter']}: {rule['summary']}" + + +def get_issue_severity(code: str) -> LintSeverity: + # "B901": `return x` inside a generator + # "B902": Invalid first argument to a method + # "B903": __slots__ efficiency + # "B950": Line too long + # "C4": Flake8 Comprehensions + # "C9": Cyclomatic complexity + # "E2": PEP8 horizontal whitespace "errors" + # "E3": PEP8 blank line "errors" + # "E5": PEP8 line length "errors" + # "T400": type checking Notes + # "T49": internal type checker errors or unmatched messages + if any( + code.startswith(x) + for x in ( + "B9", + "C4", + "C9", + "E2", + "E3", + "E5", + "T400", + "T49", + "PLC", + "PLR", + ) + ): + return LintSeverity.ADVICE + + # "F821": Undefined name + # "E999": syntax error + if any(code.startswith(x) for x in ("F821", "E999", "PLE")): + return LintSeverity.ERROR + + # "F": PyFlakes Error + # "B": flake8-bugbear Error + # "E": PEP8 "Error" + # "W": PEP8 Warning + # possibly other plugins... + return LintSeverity.WARNING + + +def format_lint_message( + message: str, code: str, rules: dict[str, str], show_disable: bool +) -> str: + if rules: + message += f".\n{rules.get(code) or ''}" + message += ".\nSee https://beta.ruff.rs/docs/rules/" + if show_disable: + message += f".\n\nTo disable, use ` # noqa: {code}`" + return message + + +def check_files( + filenames: list[str], + severities: dict[str, LintSeverity], + *, + config: str | None, + retries: int, + timeout: int, + explain: bool, + show_disable: bool, +) -> list[LintMessage]: + try: + proc = run_command( + [ + sys.executable, + "-m", + "ruff", + "--exit-zero", + "--quiet", + "--format=json", + *([f"--config={config}"] if config else []), + *filenames, + ], + retries=retries, + timeout=timeout, + check=True, + ) + except (OSError, subprocess.CalledProcessError) as err: + return [ + LintMessage( + path=None, + line=None, + char=None, + code=LINTER_CODE, + severity=LintSeverity.ERROR, + name="command-failed", + original=None, + replacement=None, + description=( + f"Failed due to {err.__class__.__name__}:\n{err}" + if not isinstance(err, subprocess.CalledProcessError) + else ( + f"COMMAND (exit code {err.returncode})\n" + f"{' '.join(as_posix(x) for x in err.cmd)}\n\n" + f"STDERR\n{err.stderr.decode('utf-8').strip() or '(empty)'}\n\n" + f"STDOUT\n{err.stdout.decode('utf-8').strip() or '(empty)'}" + ) + ), + ) + ] + + stdout = str(proc.stdout, "utf-8").strip() + vulnerabilities = json.loads(stdout) + + if explain: + all_codes = {v["code"] for v in vulnerabilities} + rules = {code: explain_rule(code) for code in all_codes} + else: + rules = {} + + return [ + LintMessage( + path=vuln["filename"], + name=vuln["code"], + description=( + format_lint_message( + vuln["message"], + vuln["code"], + rules, + show_disable, + ) + ), + line=int(vuln["location"]["row"]), + char=int(vuln["location"]["column"]), + code=LINTER_CODE, + severity=severities.get(vuln["code"], get_issue_severity(vuln["code"])), + original=None, + replacement=None, + ) + for vuln in vulnerabilities + ] + + +def check_file_for_fixes( + filename: str, + *, + config: str | None, + retries: int, + timeout: int, +) -> list[LintMessage]: + try: + with open(filename, "rb") as f: + original = f.read() + with open(filename, "rb") as f: + proc_fix = run_command( + [ + sys.executable, + "-m", + "ruff", + "--fix-only", + "--exit-zero", + *([f"--config={config}"] if config else []), + "--stdin-filename", + filename, + "-", + ], + stdin=f, + retries=retries, + timeout=timeout, + check=True, + ) + except (OSError, subprocess.CalledProcessError) as err: + return [ + LintMessage( + path=None, + line=None, + char=None, + code=LINTER_CODE, + severity=LintSeverity.ERROR, + name="command-failed", + original=None, + replacement=None, + description=( + f"Failed due to {err.__class__.__name__}:\n{err}" + if not isinstance(err, subprocess.CalledProcessError) + else ( + f"COMMAND (exit code {err.returncode})\n" + f"{' '.join(as_posix(x) for x in err.cmd)}\n\n" + f"STDERR\n{err.stderr.decode('utf-8').strip() or '(empty)'}\n\n" + f"STDOUT\n{err.stdout.decode('utf-8').strip() or '(empty)'}" + ) + ), + ) + ] + + replacement = proc_fix.stdout + if original == replacement: + return [] + + return [ + LintMessage( + path=filename, + name="format", + description="Run `lintrunner -a` to apply this patch.", + line=None, + char=None, + code=LINTER_CODE, + severity=LintSeverity.WARNING, + original=original.decode("utf-8"), + replacement=replacement.decode("utf-8"), + ) + ] + + +def main() -> None: + parser = argparse.ArgumentParser( + description=f"Ruff linter. Linter code: {LINTER_CODE}. Use with RUFF-FIX to auto-fix issues.", + fromfile_prefix_chars="@", + ) + parser.add_argument( + "--config", + default=None, + help="Path to the `pyproject.toml` or `ruff.toml` file to use for configuration", + ) + parser.add_argument( + "--explain", + action="store_true", + help="Explain a rule", + ) + parser.add_argument( + "--show-disable", + action="store_true", + help="Show how to disable a lint message", + ) + parser.add_argument( + "--timeout", + default=90, + type=int, + help="Seconds to wait for ruff", + ) + parser.add_argument( + "--severity", + action="append", + help="map code to severity (e.g. `F401:advice`). This option can be used multiple times.", + ) + parser.add_argument( + "--no-fix", + action="store_true", + help="Do not suggest fixes", + ) + add_default_options(parser) + args = parser.parse_args() + + logging.basicConfig( + format="<%(threadName)s:%(levelname)s> %(message)s", + level=logging.NOTSET + if args.verbose + else logging.DEBUG + if len(args.filenames) < 1000 + else logging.INFO, + stream=sys.stderr, + ) + + severities: dict[str, LintSeverity] = {} + if args.severity: + for severity in args.severity: + parts = severity.split(":", 1) + assert len(parts) == 2, f"invalid severity `{severity}`" + severities[parts[0]] = LintSeverity(parts[1]) + + lint_messages = check_files( + args.filenames, + severities=severities, + config=args.config, + retries=args.retries, + timeout=args.timeout, + explain=args.explain, + show_disable=args.show_disable, + ) + for lint_message in lint_messages: + lint_message.display() + + if args.no_fix or not lint_messages: + # If we're not fixing, we can exit early + return + + files_with_lints = {lint.path for lint in lint_messages if lint.path is not None} + with concurrent.futures.ThreadPoolExecutor( + max_workers=os.cpu_count(), + thread_name_prefix="Thread", + ) as executor: + futures = { + executor.submit( + check_file_for_fixes, + path, + config=args.config, + retries=args.retries, + timeout=args.timeout, + ): path + for path in files_with_lints + } + for future in concurrent.futures.as_completed(futures): + try: + for lint_message in future.result(): + lint_message.display() + except Exception: # Catch all exceptions for lintrunner + logging.critical('Failed at "%s".', futures[future]) + raise + + +if __name__ == "__main__": + main() diff --git a/tools/test/test_selective_build.py b/tools/test/test_selective_build.py index 4b96ec98d39..a1f011cfcfd 100644 --- a/tools/test/test_selective_build.py +++ b/tools/test/test_selective_build.py @@ -1,6 +1,6 @@ import unittest -from torchgen.selective_build.operator import * +from torchgen.selective_build.operator import * # noqa: F403 from torchgen.model import Location, NativeFunction from torchgen.selective_build.selector import ( combine_selective_builders, diff --git a/torch/distributed/pipeline/sync/microbatch.py b/torch/distributed/pipeline/sync/microbatch.py index c25747046d1..5b8aca25754 100644 --- a/torch/distributed/pipeline/sync/microbatch.py +++ b/torch/distributed/pipeline/sync/microbatch.py @@ -143,7 +143,7 @@ class Batch: self._values = value def _setitem_by_slice(self, index: slice, value) -> None: - if not (index.start is index.stop is index.step is None): + if not (index.start is index.stop is index.step is None): # noqa: E714 raise NotImplementedError("only slice [:] supported") if not self.atomic: