mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
To do https://github.com/pytorch/pytorch/pull/75972 in a lint-free way, I need to reformat all the imports (which are now incorrectly indented). This is a pain to do manually, so I plan to ask black to do it for me. But the files are not black-compliant, so first reformat everything with black. This commit was generated with: ``` black tools/codegen ``` Signed-off-by: Edward Z. Yang <ezyangfb.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/76015 Approved by: https://github.com/bdhirsh
289 lines
9.5 KiB
Python
289 lines
9.5 KiB
Python
import contextlib
|
|
import functools
|
|
import hashlib
|
|
import os
|
|
import re
|
|
import textwrap
|
|
from argparse import Namespace
|
|
from typing import (
|
|
Tuple,
|
|
List,
|
|
Iterable,
|
|
Iterator,
|
|
Callable,
|
|
Sequence,
|
|
TypeVar,
|
|
Optional,
|
|
Dict,
|
|
Any,
|
|
Union,
|
|
Set,
|
|
NoReturn,
|
|
)
|
|
from enum import Enum
|
|
|
|
from tools.codegen.code_template import CodeTemplate
|
|
|
|
# Safely load fast C Yaml loader/dumper if they are available
|
|
try:
|
|
from yaml import CSafeLoader as Loader
|
|
except ImportError:
|
|
from yaml import SafeLoader as Loader # type: ignore[misc]
|
|
|
|
try:
|
|
from yaml import CSafeDumper as Dumper
|
|
except ImportError:
|
|
from yaml import SafeDumper as Dumper # type: ignore[misc]
|
|
YamlDumper = Dumper
|
|
|
|
# A custom loader for YAML that errors on duplicate keys.
# This doesn't happen by default: see https://github.com/yaml/pyyaml/issues/165
class YamlLoader(Loader):
    """YAML loader that refuses mappings in which the same key appears twice."""

    def construct_mapping(self, node, deep=False):  # type: ignore[no-untyped-def]
        """Construct the mapping for ``node`` after checking its keys are unique.

        Every key node is constructed once up front purely for the uniqueness
        check; the mapping itself is then built by the parent loader.

        Raises:
            AssertionError: if a key occurs more than once in ``node``.
        """
        seen_keys = []
        for key_node, _ in node.value:
            key = self.construct_object(key_node, deep=deep)  # type: ignore[no-untyped-call]
            assert (
                key not in seen_keys
            ), f"Found a duplicate key in the yaml. key={key}, line={node.start_mark.line}"
            seen_keys.append(key)
        # All keys are distinct; let the base class build the actual mapping.
        return super().construct_mapping(node, deep=deep)  # type: ignore[no-untyped-call]
|
|
|
|
|
|
# Many of these functions share logic for defining both the definition
# and declaration (for example, the function signature is the same), so
# we organize them into one function that takes a Target to say which
# code we want.
#
# This is an OPEN enum (we may add more cases to it in the future), so be sure
# to explicitly specify with Union[Literal[Target.XXX]] what targets are valid
# for your use.
#
# Members are numbered 1..6 in declaration order.
class Target(Enum):
    # top level namespace (not including at)
    DEFINITION = 1
    DECLARATION = 2
    # TORCH_LIBRARY(...) { ... }
    REGISTRATION = 3
    # namespace { ... }
    ANONYMOUS_DEFINITION = 4
    # namespace cpu { ... }
    NAMESPACED_DEFINITION = 5
    NAMESPACED_DECLARATION = 6
|
|
|
|
# Matches "foo" in "foo, bar" but not "foobar". Used to search for the
# occurrence of a parameter in the derivative formula.
#
# This is a pattern *template*: the "{}" is a placeholder for the identifier
# to look for (presumably filled in by callers with str.format -- the call
# sites are not in this file). The identifier must be delimited on each side
# by either the start/end of the string or a non-word character.
IDENT_REGEX = r"(^|\W){}($|\W)"
|
|
|
|
# TODO: Use a real parser here; this will get bamboozled
def split_name_params(schema: str) -> Tuple[str, List[str]]:
    """Split a schema string such as ``name.overload(a, b)`` into its parts.

    Returns the base name (any ``.overload`` suffix is dropped) together with
    the text between the parentheses split on ``", "``. An empty parameter
    list therefore yields ``[""]``.

    Raises:
        RuntimeError: if ``schema`` does not start with ``name(...)`` or
            ``name.overload(...)``.
    """
    match = re.match(r"(\w+)(\.\w+)?\((.*)\)", schema)
    if match is not None:
        # Group 2 (the optional ".overload" part) is intentionally ignored.
        return match.group(1), match.group(3).split(", ")
    raise RuntimeError(f"Unsupported function schema: {schema}")
|
|
|
|
|
|
# Generic type variables used in the signatures of the helpers in this file
# (T: element type of the input iterable, S: element type of the output).
T = TypeVar("T")
S = TypeVar("S")
|
|
|
|
# These two functions purposely return generators in analogy to map()
|
|
# so that you don't mix up when you need to list() them
|
|
|
|
# Map over function that may return None; omit Nones from output sequence
def mapMaybe(func: Callable[[T], Optional[S]], xs: Iterable[T]) -> Iterator[S]:
    """Lazily yield ``func(x)`` for each ``x`` in ``xs``, skipping ``None`` results.

    Only ``None`` is dropped; other falsy results (``0``, ``""``, ...) are kept.
    Nothing is evaluated until the returned generator is iterated.
    """
    yield from (result for result in map(func, xs) if result is not None)
|
|
|
|
|
|
# Map over function that returns sequences and cat them all together
def concatMap(func: Callable[[T], Sequence[S]], xs: Iterable[T]) -> Iterator[S]:
    """Lazily yield every element of ``func(x)`` for each ``x`` in ``xs``, in order."""
    for item in xs:
        yield from func(item)
|
|
|
|
|
|
# Conveniently add error context to exceptions raised. Lets us
# easily say that an error occurred while processing a specific
# context.
@contextlib.contextmanager
def context(msg_fn: Callable[[], str]) -> Iterator[None]:
    """Context manager that appends ``msg_fn()`` to exceptions raised in its body.

    ``msg_fn`` is only called when an ``Exception`` actually escapes the body.
    Its (indented) text is placed on a new line after the exception's first
    argument, or becomes the first argument if the exception had none. The
    remaining arguments are untouched and the same exception is re-raised.
    """
    try:
        yield
    except Exception as err:
        # TODO: this does the wrong thing with KeyError
        note = textwrap.indent(msg_fn(), " ")
        if err.args:
            note = f"{err.args[0]}\n{note}"
        err.args = (note, *err.args[1:])
        raise
|
|
|
|
|
|
# A little trick from https://github.com/python/mypy/issues/6366
# for getting mypy to do exhaustiveness checking
# TODO: put this somewhere else, maybe
def assert_never(x: NoReturn) -> NoReturn:
    """Always raise ``AssertionError`` naming the runtime type of ``x``.

    Intended for the final ``else`` of an exhaustive type dispatch: if the
    branch is ever reached at runtime, the unexpected type is reported.
    """
    type_name = type(x).__name__
    raise AssertionError(f"Unhandled type: {type_name}")
|
|
|
|
|
|
@functools.lru_cache(maxsize=None)
def _read_template(template_fn: str) -> CodeTemplate:
    """Load the ``CodeTemplate`` stored at ``template_fn``.

    Results are memoized per path (unbounded cache), so each template file is
    loaded at most once per process.
    """
    loaded = CodeTemplate.from_file(template_fn)
    return loaded
|
|
|
|
|
|
# String hash that's stable across different executions, unlike builtin hash
def string_stable_hash(s: str) -> int:
    """Return a deterministic integer hash of ``s``.

    The value is the SHA-1 digest of the latin-1 encoding of ``s``, read as a
    little-endian integer. Characters outside latin-1 make the encoding step
    raise ``UnicodeEncodeError``.
    """
    hasher = hashlib.sha1()
    hasher.update(s.encode("latin1"))
    return int.from_bytes(hasher.digest(), "little")
|
|
|
|
|
|
# A small abstraction for writing out generated files and keeping track
# of what files have been written (so you can write out a list of output
# files)
class FileManager:
    """Writes generated files below ``install_dir`` and records what was written.

    Every path handed to ``write_with_template`` (directly or through ``write``
    / ``write_sharded``) is remembered in ``filenames``, even in dry-run mode,
    so that ``write_outputs`` can later emit the list of outputs.
    """

    # Directory prefix prepended to every output filename.
    install_dir: str
    # Directory in which template files are looked up.
    template_dir: str
    # When true, output paths are recorded but nothing is rendered or written.
    dry_run: bool
    # Full paths ("<install_dir>/<filename>") of all outputs requested so far.
    filenames: Set[str]

    def __init__(self, install_dir: str, template_dir: str, dry_run: bool) -> None:
        self.install_dir = install_dir
        self.template_dir = template_dir
        self.filenames = set()
        self.dry_run = dry_run

    def _write_if_changed(self, filename: str, contents: str) -> None:
        """Write ``contents`` to ``filename`` unless the file already holds it.

        A file that cannot be read (for example because it does not exist) is
        treated as having different contents. The parent directory is created
        when needed.
        """
        old_contents: Optional[str]
        try:
            with open(filename, "r") as f:
                old_contents = f.read()
        except IOError:
            old_contents = None
        if contents != old_contents:
            # Create output directory if it doesn't exist. A bare filename has
            # an empty dirname, and os.makedirs("") raises, so skip it then.
            out_dir = os.path.dirname(filename)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
            with open(filename, "w") as f:
                f.write(contents)

    def write_with_template(
        self,
        filename: str,
        template_fn: str,
        env_callable: Callable[[], Union[str, Dict[str, Any]]],
    ) -> None:
        """Record ``filename`` as an output and, unless dry-running, write it.

        Args:
            filename: output path relative to ``install_dir``.
            template_fn: template path relative to ``template_dir``; only used
                when ``env_callable`` returns a dict.
            env_callable: called only when not in dry-run mode. A ``str``
                result is written verbatim; a ``dict`` result is substituted
                into the template (a ``generated_comment`` entry is added to
                that dict if it has none).

        Raises:
            AssertionError: if the same output file is requested twice.
        """
        filename = "{}/{}".format(self.install_dir, filename)
        assert filename not in self.filenames, f"duplicate file write {filename}"
        self.filenames.add(filename)
        if not self.dry_run:
            env = env_callable()
            if isinstance(env, dict):
                # TODO: Update the comment reference to the correct location
                if "generated_comment" not in env:
                    # NOTE(review): the marker is assembled from two pieces,
                    # presumably so this source file itself is not detected as
                    # a generated file -- confirm before joining the literals.
                    comment = "@" + "generated by tools/codegen/gen.py"
                    comment += " from {}".format(os.path.basename(template_fn))
                    env["generated_comment"] = comment
                template = _read_template(os.path.join(self.template_dir, template_fn))
                self._write_if_changed(filename, template.substitute(env))
            elif isinstance(env, str):
                self._write_if_changed(filename, env)
            else:
                assert_never(env)

    def write(
        self,
        filename: str,
        env_callable: Callable[[], Union[str, Dict[str, Any]]],
    ) -> None:
        """Like ``write_with_template`` with the template named after the output."""
        self.write_with_template(filename, filename, env_callable)

    def write_sharded(
        self,
        filename: str,
        items: Iterable[T],
        *,
        key_fn: Callable[[T], str],
        env_callable: Callable[[T], Dict[str, List[str]]],
        num_shards: int,
        base_env: Optional[Dict[str, Any]] = None,
        sharded_keys: Set[str],
    ) -> None:
        """Render the template ``filename`` once per shard plus an "Everything" file.

        Each item is assigned to shard ``string_stable_hash(key_fn(item)) %
        num_shards``; the lists in its environment are appended both to that
        shard and to the combined "Everything" environment. Outputs are named
        ``<base>_<i><ext>`` and ``<base>Everything<ext>``, where ``<ext>``
        starts at the last "." of ``filename``.

        Args:
            filename: template name; also the basis for the output names.
            items: objects to distribute over the shards (ignored in dry runs).
            key_fn: returns the string that decides an item's shard.
            env_callable: returns the per-item lists to append, keyed by name.
            num_shards: number of numbered shards to produce.
            base_env: entries copied into every shard environment up front.
            sharded_keys: the only keys ``env_callable`` results may contain.

        Raises:
            AssertionError: if a sharded key in ``base_env`` is not a list, or
                an item environment uses a key missing from ``sharded_keys``.
        """
        everything: Dict[str, Any] = {"shard_id": "Everything"}
        shards: List[Dict[str, Any]] = [
            {"shard_id": f"_{i}"} for i in range(num_shards)
        ]
        all_shards = [everything] + shards

        if base_env is not None:
            for shard in all_shards:
                shard.update(base_env)

        for key in sharded_keys:
            for shard in all_shards:
                if key in shard:
                    assert isinstance(
                        shard[key], list
                    ), "sharded keys in base_env must be a list"
                    # Copy so the in-place appends below never touch the list
                    # object shared through base_env.
                    shard[key] = shard[key].copy()
                else:
                    shard[key] = []

        def merge_env(into: Dict[str, List[str]], from_: Dict[str, List[str]]) -> None:
            # Append every list of `from_` onto the same-named list of `into`.
            for k, v in from_.items():
                assert k in sharded_keys, f"undeclared sharded key {k}"
                into[k] += v

        if self.dry_run:
            # Dry runs don't write any templates, so incomplete environments are fine
            items = ()

        for item in items:
            key = key_fn(item)
            sid = string_stable_hash(key) % num_shards
            env = env_callable(item)

            merge_env(shards[sid], env)
            merge_env(everything, env)

        dot_pos = filename.rfind(".")
        if dot_pos == -1:
            dot_pos = len(filename)
        base_filename = filename[:dot_pos]
        extension = filename[dot_pos:]

        for shard in all_shards:
            shard_id = shard["shard_id"]
            # The lambda is invoked (if at all) inside write_with_template,
            # before the loop variable is rebound.
            self.write_with_template(
                f"{base_filename}{shard_id}{extension}", filename, lambda: shard
            )

        # filenames is used to track compiled files, but FooEverything.cpp isn't meant to be compiled
        self.filenames.discard(
            f"{self.install_dir}/{base_filename}Everything{extension}"
        )

    def write_outputs(self, variable_name: str, filename: str) -> None:
        """Write a file containing the list of all outputs which are
        generated by this script.

        The content has the form ``set(<variable_name> "<path>" ...)`` with the
        recorded paths in sorted order, one per line. ``filename`` is used as
        given (it is not prefixed with ``install_dir``).
        """
        content = "set({}\n {})".format(
            variable_name,
            "\n ".join('"' + name + '"' for name in sorted(self.filenames)),
        )
        self._write_if_changed(filename, content)
|
|
|
|
|
|
# Helper function to generate file manager
def make_file_manager(
    options: Namespace, install_dir: Optional[str] = None
) -> FileManager:
    """Build a ``FileManager`` from parsed command-line ``options``.

    Templates are looked up in ``<options.source_path>/templates``. Output goes
    to ``install_dir`` when a non-empty value is given, otherwise to
    ``options.install_dir``. ``options.dry_run`` is passed through unchanged.
    """
    return FileManager(
        template_dir=os.path.join(options.source_path, "templates"),
        install_dir=install_dir or options.install_dir,
        dry_run=options.dry_run,
    )
|