mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[dynamo] Add a lint rule to restrict what 3P library one can import (#143312)
As title, this patch prevents developers from importing third party
libraries to patch things in Dynamo, unless there's no other easy
workaround (in which case one would add the library to the allowlist in
`import_linter.py`, as instructed by the lint error).
For instance, if we remove `einops` from the allowlist, we'd get this
```verbatim
>>> Lint for torch/_dynamo/decorators.py:
Error (IMPORT) Disallowed import
importing from einops is not allowed, if you believe there's a valid
reason, please add it to import_linter.py
608 |# Note: this carefully avoids eagerly import einops.
609 |# TODO: we should delete this whole _allow_in_graph_einops logic by approximately 2024 Q2
610 |def _allow_in_graph_einops():
>>> 611 | import einops
612 |
613 | try:
614 | # requires einops > 0.6.1, torch >= 2.0
Error (IMPORT) Disallowed import
importing from einops is not allowed, if you believe there's a valid
reason, please add it to import_linter.py
612 |
613 | try:
614 | # requires einops > 0.6.1, torch >= 2.0
>>> 615 | from einops._torch_specific import ( # type: ignore[attr-defined] # noqa: F401
616 | _ops_were_registered_in_torchdynamo,
617 | )
618 |
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143312
Approved by: https://github.com/zou3519
This commit is contained in:
parent
8e78345d69
commit
629de4da60
2 changed files with 425 additions and 0 deletions
|
|
@ -1733,3 +1733,17 @@ include_patterns = [
|
|||
'torch/**/not-exist.py'
|
||||
]
|
||||
is_formatter = false
|
||||
|
||||
# `import_linter` reports on importing disallowed third party libraries.
|
||||
[[linter]]
|
||||
code = 'IMPORT_LINTER'
|
||||
command = [
|
||||
'python3',
|
||||
'tools/linter/adapters/import_linter.py',
|
||||
'--',
|
||||
'@{{PATHSFILE}}'
|
||||
]
|
||||
include_patterns = [
|
||||
'torch/_dynamo/**',
|
||||
]
|
||||
is_formatter = false
|
||||
|
|
|
|||
411
tools/linter/adapters/import_linter.py
Normal file
411
tools/linter/adapters/import_linter.py
Normal file
|
|
@ -0,0 +1,411 @@
|
|||
"""
|
||||
Checks files to make sure there are no imports from disallowed third party
|
||||
libraries.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import token
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import List, NamedTuple, Set, TYPE_CHECKING
|
||||
|
||||
|
||||
_PARENT = Path(__file__).parent.absolute()
|
||||
_PATH = [Path(p).absolute() for p in sys.path]
|
||||
|
||||
if TYPE_CHECKING or _PARENT not in _PATH:
|
||||
from . import _linter
|
||||
else:
|
||||
import _linter
|
||||
|
||||
|
||||
class LintSeverity(str, Enum):
|
||||
ERROR = "error"
|
||||
WARNING = "warning"
|
||||
ADVICE = "advice"
|
||||
DISABLED = "disabled"
|
||||
|
||||
|
||||
class LintMessage(NamedTuple):
|
||||
path: str | None
|
||||
line: int | None
|
||||
char: int | None
|
||||
code: str
|
||||
severity: LintSeverity
|
||||
name: str
|
||||
original: str | None
|
||||
replacement: str | None
|
||||
description: str | None
|
||||
|
||||
|
||||
LINTER_CODE = "NEWLINE"
|
||||
CURRENT_FILE_NAME = os.path.basename(__file__)
|
||||
_MODULE_NAME_ALLOW_LIST: Set[str] = set()
|
||||
|
||||
# Add builtin modules.
|
||||
if sys.version_info >= (3, 10):
|
||||
_MODULE_NAME_ALLOW_LIST.update(sys.stdlib_module_names)
|
||||
else:
|
||||
assert (sys.version_info.major, sys.version_info.minor) == (3, 9)
|
||||
# Taken from `stdlib_list("3.9")` to avoid introducing a new dependency.
|
||||
_MODULE_NAME_ALLOW_LIST.update(
|
||||
[
|
||||
"__future__",
|
||||
"_abc",
|
||||
"_aix_support",
|
||||
"_ast",
|
||||
"_bootlocale",
|
||||
"_bootsubprocess",
|
||||
"_codecs",
|
||||
"_collections",
|
||||
"_collections_abc",
|
||||
"_compat_pickle",
|
||||
"_compression",
|
||||
"_crypt",
|
||||
"_functools",
|
||||
"_hashlib",
|
||||
"_imp",
|
||||
"_io",
|
||||
"_locale",
|
||||
"_lsprof",
|
||||
"_markupbase",
|
||||
"_operator",
|
||||
"_osx_support",
|
||||
"_peg_parser",
|
||||
"_posixsubprocess",
|
||||
"_py_abc",
|
||||
"_pydecimal",
|
||||
"_pyio",
|
||||
"_random",
|
||||
"_signal",
|
||||
"_sitebuiltins",
|
||||
"_socket",
|
||||
"_sre",
|
||||
"_ssl",
|
||||
"_stat",
|
||||
"_string",
|
||||
"_strptime",
|
||||
"_symtable",
|
||||
"_sysconfigdata_x86_64_conda_cos6_linux_gnu",
|
||||
"_sysconfigdata_x86_64_conda_linux_gnu",
|
||||
"_thread",
|
||||
"_threading_local",
|
||||
"_tracemalloc",
|
||||
"_uuid",
|
||||
"_warnings",
|
||||
"_weakref",
|
||||
"_weakrefset",
|
||||
"abc",
|
||||
"aifc",
|
||||
"antigravity",
|
||||
"argparse",
|
||||
"array",
|
||||
"ast",
|
||||
"asynchat",
|
||||
"asyncio",
|
||||
"asyncore",
|
||||
"atexit",
|
||||
"audioop",
|
||||
"base64",
|
||||
"bdb",
|
||||
"binascii",
|
||||
"binhex",
|
||||
"bisect",
|
||||
"builtins",
|
||||
"bz2",
|
||||
"cProfile",
|
||||
"calendar",
|
||||
"cgi",
|
||||
"cgitb",
|
||||
"chunk",
|
||||
"cmath",
|
||||
"cmd",
|
||||
"code",
|
||||
"codecs",
|
||||
"codeop",
|
||||
"collections",
|
||||
"colorsys",
|
||||
"compileall",
|
||||
"concurrent",
|
||||
"configparser",
|
||||
"contextlib",
|
||||
"contextvars",
|
||||
"copy",
|
||||
"copyreg",
|
||||
"crypt",
|
||||
"csv",
|
||||
"ctypes",
|
||||
"curses",
|
||||
"dataclasses",
|
||||
"datetime",
|
||||
"dbm",
|
||||
"decimal",
|
||||
"difflib",
|
||||
"dis",
|
||||
"distutils",
|
||||
"doctest",
|
||||
"email",
|
||||
"encodings",
|
||||
"ensurepip",
|
||||
"enum",
|
||||
"errno",
|
||||
"faulthandler",
|
||||
"fcntl",
|
||||
"filecmp",
|
||||
"fileinput",
|
||||
"fnmatch",
|
||||
"formatter",
|
||||
"fractions",
|
||||
"ftplib",
|
||||
"functools",
|
||||
"gc",
|
||||
"genericpath",
|
||||
"getopt",
|
||||
"getpass",
|
||||
"gettext",
|
||||
"glob",
|
||||
"graphlib",
|
||||
"grp",
|
||||
"gzip",
|
||||
"hashlib",
|
||||
"heapq",
|
||||
"hmac",
|
||||
"html",
|
||||
"http",
|
||||
"idlelib",
|
||||
"imaplib",
|
||||
"imghdr",
|
||||
"imp",
|
||||
"importlib",
|
||||
"inspect",
|
||||
"io",
|
||||
"ipaddress",
|
||||
"itertools",
|
||||
"json",
|
||||
"keyword",
|
||||
"lib2to3",
|
||||
"linecache",
|
||||
"locale",
|
||||
"logging",
|
||||
"lzma",
|
||||
"mailbox",
|
||||
"mailcap",
|
||||
"marshal",
|
||||
"math",
|
||||
"mimetypes",
|
||||
"mmap",
|
||||
"modulefinder",
|
||||
"msilib",
|
||||
"msvcrt",
|
||||
"multiprocessing",
|
||||
"netrc",
|
||||
"nis",
|
||||
"nntplib",
|
||||
"ntpath",
|
||||
"nturl2path",
|
||||
"numbers",
|
||||
"opcode",
|
||||
"operator",
|
||||
"optparse",
|
||||
"os",
|
||||
"ossaudiodev",
|
||||
"parser",
|
||||
"pathlib",
|
||||
"pdb",
|
||||
"pickle",
|
||||
"pickletools",
|
||||
"pipes",
|
||||
"pkgutil",
|
||||
"platform",
|
||||
"plistlib",
|
||||
"poplib",
|
||||
"posix",
|
||||
"posixpath",
|
||||
"pprint",
|
||||
"profile",
|
||||
"pstats",
|
||||
"pty",
|
||||
"pwd",
|
||||
"py_compile",
|
||||
"pyclbr",
|
||||
"pydoc",
|
||||
"pydoc_data",
|
||||
"queue",
|
||||
"quopri",
|
||||
"random",
|
||||
"re",
|
||||
"readline",
|
||||
"reprlib",
|
||||
"resource",
|
||||
"rlcompleter",
|
||||
"runpy",
|
||||
"sched",
|
||||
"secrets",
|
||||
"select",
|
||||
"selectors",
|
||||
"shelve",
|
||||
"shlex",
|
||||
"shutil",
|
||||
"signal",
|
||||
"site",
|
||||
"smtpd",
|
||||
"smtplib",
|
||||
"sndhdr",
|
||||
"socket",
|
||||
"socketserver",
|
||||
"spwd",
|
||||
"sqlite3",
|
||||
"sre_compile",
|
||||
"sre_constants",
|
||||
"sre_parse",
|
||||
"ssl",
|
||||
"stat",
|
||||
"statistics",
|
||||
"string",
|
||||
"stringprep",
|
||||
"struct",
|
||||
"subprocess",
|
||||
"sunau",
|
||||
"symbol",
|
||||
"symtable",
|
||||
"sys",
|
||||
"sysconfig",
|
||||
"syslog",
|
||||
"tabnanny",
|
||||
"tarfile",
|
||||
"telnetlib",
|
||||
"tempfile",
|
||||
"termios",
|
||||
"test",
|
||||
"textwrap",
|
||||
"this",
|
||||
"threading",
|
||||
"time",
|
||||
"timeit",
|
||||
"tkinter",
|
||||
"token",
|
||||
"tokenize",
|
||||
"trace",
|
||||
"traceback",
|
||||
"tracemalloc",
|
||||
"tty",
|
||||
"turtle",
|
||||
"turtledemo",
|
||||
"types",
|
||||
"typing",
|
||||
"unicodedata",
|
||||
"unittest",
|
||||
"urllib",
|
||||
"uu",
|
||||
"uuid",
|
||||
"venv",
|
||||
"warnings",
|
||||
"wave",
|
||||
"weakref",
|
||||
"webbrowser",
|
||||
"winreg",
|
||||
"winsound",
|
||||
"wsgiref",
|
||||
"xdrlib",
|
||||
"xml",
|
||||
"xmlrpc",
|
||||
"xxsubtype",
|
||||
"zipapp",
|
||||
"zipfile",
|
||||
"zipimport",
|
||||
"zlib",
|
||||
"zoneinfo",
|
||||
]
|
||||
)
|
||||
|
||||
# Add the allowed third party libraries. Please avoid updating this unless you
|
||||
# understand the risks -- see `_ERROR_MESSAGE` for why.
|
||||
_MODULE_NAME_ALLOW_LIST.update(
|
||||
[
|
||||
"sympy",
|
||||
"einops",
|
||||
"libfb",
|
||||
"torch",
|
||||
"tvm",
|
||||
"_pytest",
|
||||
"tabulate",
|
||||
"optree",
|
||||
"typing_extensions",
|
||||
"triton",
|
||||
"functorch",
|
||||
"torchrec",
|
||||
"numpy",
|
||||
"torch_xla",
|
||||
]
|
||||
)
|
||||
|
||||
_ERROR_MESSAGE = """
|
||||
Please do not import third-party modules in PyTorch unless they're explicit
|
||||
requirements of PyTorch. Imports of a third-party library may have side effects
|
||||
and other unintentional behavior. If you're just checking if a module exists,
|
||||
use sys.modules.get("torchrec") or the like.
|
||||
"""
|
||||
|
||||
|
||||
def check_file(filepath: str) -> List[LintMessage]:
|
||||
path = Path(filepath)
|
||||
file = _linter.PythonFile("import_linter", path)
|
||||
lint_messages = []
|
||||
for line_number, line_of_tokens in enumerate(file.token_lines):
|
||||
# Skip indents
|
||||
idx = 0
|
||||
for tok in line_of_tokens:
|
||||
if tok.type == token.INDENT:
|
||||
idx += 1
|
||||
else:
|
||||
break
|
||||
|
||||
# Look for either "import foo..." or "from foo..."
|
||||
if idx + 1 < len(line_of_tokens):
|
||||
tok0 = line_of_tokens[idx]
|
||||
tok1 = line_of_tokens[idx + 1]
|
||||
if tok0.type == token.NAME and tok0.string in {"import", "from"}:
|
||||
if tok1.type == token.NAME:
|
||||
module_name = tok1.string
|
||||
if module_name not in _MODULE_NAME_ALLOW_LIST:
|
||||
msg = LintMessage(
|
||||
path=filepath,
|
||||
line=line_number,
|
||||
char=None,
|
||||
code="IMPORT",
|
||||
severity=LintSeverity.ERROR,
|
||||
name="Disallowed import",
|
||||
original=None,
|
||||
replacement=None,
|
||||
description=_ERROR_MESSAGE,
|
||||
)
|
||||
lint_messages.append(msg)
|
||||
return lint_messages
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="native functions linter",
|
||||
fromfile_prefix_chars="@",
|
||||
)
|
||||
parser.add_argument(
|
||||
"filepaths",
|
||||
nargs="+",
|
||||
help="paths of files to lint",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Check all files.
|
||||
all_lint_messages = []
|
||||
for filepath in args.filepaths:
|
||||
lint_messages = check_file(filepath)
|
||||
all_lint_messages.extend(lint_messages)
|
||||
|
||||
# Print out lint messages.
|
||||
for lint_message in all_lint_messages:
|
||||
print(json.dumps(lint_message._asdict()), flush=True)
|
||||
Loading…
Reference in a new issue