move some codegen utilities into utils.py (#63094)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63094

This PR:
- Moves `FileManager` and its dependencies (`assert_never` and other imports) to `utils.py`, and updates all of the call-sites with the fresh imports
- Passes the list of NativeFunction objects into `gen_trace_type` directly, instead of requiring the function to regenerate it (we already have it)

The purpose of the reshuffling is to avoid circular dependencies in the next PR, where I add codegen for the functionalization pass, which gets called from `gen.py` (but depends on some stuff from the autograd codegen - in particular, the list of view ops).

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D31942096

Pulled By: bdhirsh

fbshipit-source-id: 36118facae61f25f8922bb43ad2818c80b53504e
This commit is contained in:
Brian Hirsh 2021-10-28 10:43:11 -07:00 committed by Facebook GitHub Bot
parent b100a9ea82
commit 665c148e42
18 changed files with 171 additions and 166 deletions

View file

@ -20,7 +20,8 @@ import textwrap
from typing import Dict, List, Any
from tools.codegen.gen import parse_native_yaml, FileManager
from tools.codegen.gen import parse_native_yaml
from tools.codegen.utils import FileManager
from tools.codegen.context import with_native_function
from tools.codegen.model import BaseOperatorName, NativeFunction
import tools.codegen.api.python as python

View file

@ -66,7 +66,7 @@ def gen_autograd(
gen_inplace_or_view_type(out, native_functions_path, fns_with_diff_infos, template_path)
# operator filter not applied as tracing sources are excluded in selective build
gen_trace_type(out, native_functions_path, template_path)
gen_trace_type(out, native_funcs, template_path)
# Generate Functions.h/cpp
gen_autograd_functions_lib(
out, differentiability_infos, template_path)

View file

@ -15,7 +15,7 @@ from tools.codegen.api.types import (Binding, BaseCType, OptionalCType, tensorT,
doubleT, scalarT, stringT, boolT, intArrayRefT,
tensorListT, MutRefCType, ListCType, ArrayRefCType)
from tools.codegen.code_template import CodeTemplate
from tools.codegen.gen import FileManager
from tools.codegen.utils import FileManager
from tools.codegen.model import Argument
FUNCTION_DECLARATION = CodeTemplate("""\

View file

@ -18,8 +18,7 @@ from tools.codegen.model import (
SchemaKind, is_foreach_op,
)
from typing import List, Optional, Sequence, Tuple
from tools.codegen.gen import FileManager
from tools.codegen.utils import mapMaybe
from tools.codegen.utils import mapMaybe, FileManager
from .context import with_native_function_with_differentiability_info
from .gen_trace_type import (
MANUAL_AUTOGRAD, type_wrapper_name, tie_return_values, get_return_value

View file

@ -52,11 +52,11 @@ from tools.codegen.api.python import (PythonArgument, PythonSignature,
dispatch_lambda_return_str,
has_tensor_options,
namedtuple_fieldnames, signature)
from tools.codegen.gen import cpp_string, parse_native_yaml, FileManager
from tools.codegen.gen import cpp_string, parse_native_yaml
from tools.codegen.context import with_native_function
from tools.codegen.model import (Argument, BaseOperatorName, NativeFunction,
Type, Variant)
from tools.codegen.utils import split_name_params, YamlLoader
from tools.codegen.utils import split_name_params, YamlLoader, FileManager
from typing import Dict, Optional, List, Tuple, Set, Sequence, Callable

View file

@ -5,7 +5,7 @@ from tools.codegen.api.types import CppSignatureGroup, DispatcherSignature
from tools.codegen.api import cpp
from tools.codegen.code_template import CodeTemplate
from tools.codegen.context import with_native_function
from tools.codegen.gen import parse_native_yaml, FileManager
from tools.codegen.utils import FileManager
from tools.codegen.model import (Argument, NativeFunction, SchemaKind,
TensorOptionsArguments)
@ -405,11 +405,10 @@ def gen_trace_type_func(
'trace_wrapper_registrations': [method_registration(fn)],
}
def gen_trace_type(out: str, native_yaml_path: str, template_path: str) -> None:
def gen_trace_type(out: str, native_functions: List[NativeFunction], template_path: str) -> None:
# NOTE: see Note [Sharded File] at the top of the VariableType.cpp
# template regarding sharding of the generated files.
fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
native_functions = parse_native_yaml(native_yaml_path).native_functions
fm.write_sharded(
'TraceType.cpp',
[fn for fn in native_functions if cpp.name(fn.func) not in MANUAL_TRACER],

View file

@ -8,9 +8,9 @@ from typing import Optional, List
from tools.codegen.api.types import CppSignatureGroup
from tools.codegen.api import cpp
import tools.codegen.api.python as python
from tools.codegen.gen import parse_native_yaml, FileManager
from tools.codegen.gen import parse_native_yaml
from tools.codegen.context import with_native_function
from tools.codegen.utils import mapMaybe
from tools.codegen.utils import mapMaybe, FileManager
from tools.codegen.model import NativeFunction, TensorOptionsArguments, Variant
OPTIONAL_TYPE_PATTERN = re.compile(r"c10::optional<(.+)>")

View file

@ -47,8 +47,7 @@ from tools.codegen.api.autograd import (
from tools.codegen.api import cpp
from tools.codegen.code_template import CodeTemplate
from tools.codegen.context import native_function_manager, with_native_function
from tools.codegen.gen import FileManager
from tools.codegen.utils import mapMaybe
from tools.codegen.utils import mapMaybe, FileManager
from tools.codegen.model import (Argument, NativeFunction, SchemaKind,
SelfArgument, TensorOptionsArguments,
BaseType, ListType)

View file

@ -1,13 +1,14 @@
from tools.codegen.model import (Argument, Arguments, BaseTy, BaseType,
FunctionSchema, ListType, NativeFunction,
OptionalType, Return, SelfArgument,
TensorOptionsArguments, Type, assert_never)
TensorOptionsArguments, Type)
from tools.codegen.api.types import (ArgName, BaseCType, Binding, ConstRefCType, NamedCType, CType,
MutRefCType, ArrayCType, ListCType, VectorCType, ArrayRefCType,
OptionalCType, TupleCType, SpecialArgName, boolT, scalarT,
tensorListT, dimnameListT, tensorT, voidT,
BaseTypeToCppMapping, intArrayRefT, tensorOptionsT)
from tools.codegen import local
from tools.codegen.utils import assert_never
from typing import Optional, Sequence, Union, List, Set
# This file describes the translation of JIT schema to the public C++

View file

@ -1,10 +1,9 @@
from tools.codegen.model import (Argument, FunctionSchema, Return,
SelfArgument, TensorOptionsArguments, Type,
assert_never)
SelfArgument, TensorOptionsArguments, Type)
from tools.codegen.api.types import ArgName, Binding, NamedCType, CType
from tools.codegen.api import cpp
from tools.codegen.utils import concatMap
from tools.codegen.utils import concatMap, assert_never
import itertools
from typing import Sequence, List, Union

View file

@ -1,6 +1,5 @@
from tools.codegen.model import (Argument, FunctionSchema, Return,
SelfArgument, TensorOptionsArguments, Type,
assert_never)
SelfArgument, TensorOptionsArguments, Type)
from tools.codegen.api.types import (ArgName, BaseCType, Binding,
ConstRefCType, NamedCType, CType, MutRefCType, ListCType,
@ -8,6 +7,7 @@ from tools.codegen.api.types import (ArgName, BaseCType, Binding,
deviceT, boolT, scalarTypeT)
from tools.codegen.api import cpp
from tools.codegen import local
from tools.codegen.utils import assert_never
from typing import Union, Sequence, List, Optional

View file

@ -1,7 +1,6 @@
from tools.codegen.model import (Argument, BaseTy, BaseType, ListType,
NativeFunctionsGroup, OptionalType,
SelfArgument, TensorOptionsArguments, Type,
assert_never)
SelfArgument, TensorOptionsArguments, Type)
from tools.codegen.api.types import (ArgName, BaseCType, Binding, ArrayRefCType,
ConstRefCType, OptionalCType, NamedCType,
@ -9,6 +8,7 @@ from tools.codegen.api.types import (ArgName, BaseCType, Binding, ArrayRefCType,
optionalTensorRefT, optionalScalarRefT)
from tools.codegen.api import cpp
from tools.codegen.utils import assert_never
from typing import Union, List

View file

@ -5,11 +5,11 @@ from dataclasses import dataclass
import textwrap
from tools.codegen.context import method_with_native_function, native_function_manager
from tools.codegen.utils import Target, mapMaybe
from tools.codegen.utils import Target, mapMaybe, assert_never
from tools.codegen.model import (DispatchKey, NativeFunction,
NativeFunctionsGroup, SchemaKind,
TensorOptionsArguments,
DeviceCheckType, Argument, assert_never,
DeviceCheckType, Argument,
is_cuda_dispatch_key, BackendIndex,
gets_generated_out_inplace_wrapper)
from tools.codegen.api.types import (BaseCType, Binding, ConstRefCType,

View file

@ -1,23 +1,20 @@
import os
from typing import List, Dict, Optional, Tuple, Set, Callable, Any, Union, Sequence, TypeVar, Iterable
from typing import List, Dict, Optional, Tuple, Set, Any, Union, Sequence, TypeVar
from typing_extensions import Literal
import yaml
from collections import OrderedDict, defaultdict, namedtuple
import argparse
import pathlib
import functools
import json
from dataclasses import dataclass
import hashlib
from tools.codegen.code_template import CodeTemplate
from tools.codegen.model import (Argument, DispatchKey, FunctionSchema,
Location, NativeFunction,
NativeFunctionsGroup, OperatorName,
BackendIndex, BackendMetadata,
OptionalType, SchemaKind, SelfArgument,
TensorOptionsArguments, Type, Variant,
assert_never, is_cuda_dispatch_key,
is_cuda_dispatch_key,
is_generic_dispatch_key)
from tools.codegen.api.types import (Binding, CppSignature, CppSignatureGroup,
DispatcherSignature, NativeSignature)
@ -28,7 +25,9 @@ import tools.codegen.api.meta as meta
import tools.codegen.api.structured as structured
from tools.codegen.api.translate import translate
from tools.codegen.selective_build.selector import SelectiveBuilder
from tools.codegen.utils import Target, concatMap, context, mapMaybe, YamlDumper, YamlLoader
from tools.codegen.utils import (
Target, concatMap, context, mapMaybe, YamlDumper, YamlLoader, FileManager, assert_never
)
from tools.codegen.context import (method_with_native_function,
native_function_manager,
with_native_function_and_indices,
@ -884,131 +883,6 @@ def compute_registration_declarations(f: NativeFunction, backend_indices: Dict[D
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
@functools.lru_cache(maxsize=None)
def _read_template(template_fn: str) -> CodeTemplate:
    """Load the code template stored at path ``template_fn``.

    Results are memoized with an unbounded ``lru_cache`` keyed on the path,
    so a path that has loaded successfully is not read again by later calls.
    """
    template = CodeTemplate.from_file(template_fn)
    return template
# String hash that's stable across different executions, unlike builtin hash
def string_stable_hash(s: str) -> int:
    """Return a deterministic integer hash of ``s``.

    The string is encoded as Latin-1, hashed with SHA-1, and the digest is
    read as a little-endian unsigned integer.  Encoding raises
    ``UnicodeEncodeError`` for characters outside Latin-1.
    """
    hasher = hashlib.sha1()
    hasher.update(s.encode('latin1'))
    return int.from_bytes(hasher.digest(), byteorder='little')
# A small abstraction for writing out generated files and keeping track
# of what files have been written (so you can write out a list of output
# files)
class FileManager:
    """Write generated files under ``install_dir`` and remember their names.

    Every output path is formed as ``"{install_dir}/{filename}"`` and added to
    ``filenames``.  When ``dry_run`` is true the names are still recorded but
    no environment is computed and no file is written.
    """
    install_dir: str     # prefix joined with '/' onto every output filename
    template_dir: str    # directory that template filenames are resolved against
    dry_run: bool        # when true, only record names; never touch the disk
    filenames: Set[str]  # full paths of every output registered so far

    def __init__(self, install_dir: str, template_dir: str, dry_run: bool) -> None:
        self.install_dir = install_dir
        self.template_dir = template_dir
        self.filenames = set()
        self.dry_run = dry_run

    def _write_if_changed(self, filename: str, contents: str) -> None:
        """Write ``contents`` to ``filename`` unless it already holds that text.

        A file that cannot be read (``IOError``) is treated as having no
        previous contents, so it is always (re)written.
        """
        old_contents: Optional[str]
        try:
            with open(filename, 'r') as f:
                old_contents = f.read()
        except IOError:
            old_contents = None
        if contents != old_contents:
            with open(filename, 'w') as f:
                f.write(contents)

    def write_with_template(self, filename: str, template_fn: str,
                            env_callable: Callable[[], Union[str, Dict[str, Any]]]) -> None:
        """Register ``filename`` as an output and, unless dry-running, write it.

        ``env_callable`` is only invoked when ``dry_run`` is false.  If it
        returns a dict, the template ``template_fn`` (looked up in
        ``template_dir``) is substituted with it; a ``generated_comment``
        entry is added to that dict when absent.  If it returns a string,
        the string is written verbatim.

        Raises ``AssertionError`` if the same output is registered twice.
        """
        filename = '{}/{}'.format(self.install_dir, filename)
        # The message must be an f-string, otherwise the literal text
        # "{filename}" is reported instead of the offending path.
        assert filename not in self.filenames, f"duplicate file write {filename}"
        self.filenames.add(filename)
        if not self.dry_run:
            env = env_callable()
            if isinstance(env, dict):
                # TODO: Update the comment reference to the correct location
                if 'generated_comment' not in env:
                    # NOTE(review): the marker is presumably split in two so this
                    # source file does not itself contain the joined marker - confirm.
                    comment = "@" + "generated by tools/codegen/gen.py"
                    comment += " from {}".format(os.path.basename(template_fn))
                    env['generated_comment'] = comment
                template = _read_template(os.path.join(self.template_dir, template_fn))
                self._write_if_changed(filename, template.substitute(env))
            elif isinstance(env, str):
                self._write_if_changed(filename, env)
            else:
                assert_never(env)

    def write(self, filename: str, env_callable: Callable[[], Union[str, Dict[str, Any]]]) -> None:
        """Write ``filename`` using the template that has the same name."""
        self.write_with_template(filename, filename, env_callable)

    def write_sharded(
            self,
            filename: str,
            items: Iterable[T],
            *,
            key_fn: Callable[[T], str],
            env_callable: Callable[[T], Dict[str, List[str]]],
            num_shards: int,
            base_env: Optional[Dict[str, Any]] = None,
            sharded_keys: Set[str]
    ) -> None:
        """Distribute ``items`` over ``num_shards`` files plus an "Everything" file.

        Each item is assigned to shard ``string_stable_hash(key_fn(item)) % num_shards``.
        The lists returned by ``env_callable(item)`` are appended to that
        shard's environment and to the "Everything" environment; every key
        returned must be listed in ``sharded_keys``.  ``base_env`` entries are
        copied into every environment first.  Outputs are named by inserting
        the shard id (``_0``, ``_1``, ..., ``Everything``) before the
        extension of ``filename``, which is also used as the template name.
        """
        everything: Dict[str, Any] = {'shard_id': 'Everything'}
        shards: List[Dict[str, Any]] = [{'shard_id': f'_{i}'} for i in range(num_shards)]
        all_shards = [everything] + shards

        if base_env is not None:
            for shard in all_shards:
                shard.update(base_env)

        for key in sharded_keys:
            for shard in all_shards:
                if key in shard:
                    assert isinstance(shard[key], list), "sharded keys in base_env must be a list"
                    # Copy so that appending to one shard does not leak into
                    # the list object shared through base_env.
                    shard[key] = shard[key].copy()
                else:
                    shard[key] = []

        def merge_env(into: Dict[str, List[str]], from_: Dict[str, List[str]]) -> None:
            for k, v in from_.items():
                assert k in sharded_keys, f"undeclared sharded key {k}"
                into[k] += v

        for item in items:
            key = key_fn(item)
            sid = string_stable_hash(key) % num_shards
            env = env_callable(item)

            merge_env(shards[sid], env)
            merge_env(everything, env)

        dot_pos = filename.rfind('.')
        if dot_pos == -1:
            dot_pos = len(filename)
        base_filename = filename[:dot_pos]
        extension = filename[dot_pos:]

        for shard in all_shards:
            shard_id = shard['shard_id']
            # The lambda is called (if at all) inside write_with_template before
            # the loop advances, so capturing the loop variable is safe here.
            self.write_with_template(f"{base_filename}{shard_id}{extension}",
                                     filename,
                                     lambda: shard)

        # filenames is used to track compiled files, but FooEverything.cpp isn't meant to be compiled
        self.filenames.discard(
            f"{self.install_dir}/{base_filename}Everything{extension}")

    def write_outputs(self, filename: str) -> None:
        """Write a file containing the list of all outputs which are
        generated by this script."""
        self._write_if_changed(
            filename,
            ''.join(name + ";" for name in sorted(self.filenames)))
def get_custom_build_selector(
provided_op_registration_allowlist: Optional[List[str]],
op_selection_yaml_path: Optional[str]) -> SelectiveBuilder:

View file

@ -5,11 +5,11 @@ import yaml
import re
from collections import namedtuple, Counter, defaultdict
from typing import List, Dict, Union, Sequence, Optional
from tools.codegen.gen import FileManager, get_grouped_native_functions, parse_native_yaml
from tools.codegen.gen import get_grouped_native_functions, parse_native_yaml
from tools.codegen.model import (BackendIndex, BackendMetadata, DispatchKey,
NativeFunction, NativeFunctionsGroup, OperatorName)
from tools.codegen.selective_build.selector import SelectiveBuilder
from tools.codegen.utils import Target, concatMap, context, YamlLoader
from tools.codegen.utils import Target, concatMap, context, YamlLoader, FileManager
from tools.codegen.context import native_function_manager
import tools.codegen.dest as dest
import tools.codegen.api.dispatcher as dispatcher

View file

@ -1,16 +1,12 @@
import re
from tools.codegen.utils import assert_never
from dataclasses import dataclass
from typing import List, Dict, Optional, Iterator, Tuple, Set, NoReturn, Sequence, Callable, Union
from typing import List, Dict, Optional, Iterator, Tuple, Set, Sequence, Callable, Union
from enum import Enum, auto
import itertools
# A little trick from https://github.com/python/mypy/issues/6366
# for getting mypy to do exhaustiveness checking
# TODO: put this somewhere else, maybe
def assert_never(x: NoReturn) -> NoReturn:
    """Always raise ``AssertionError`` naming the runtime type of ``x``.

    Intended for the final ``else`` of an exhaustive ``isinstance`` chain:
    the ``NoReturn`` parameter type lets mypy report any unhandled case.
    """
    unhandled = type(x).__name__
    raise AssertionError(f"Unhandled type: {unhandled}")
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# DATA MODEL

View file

@ -1,8 +1,13 @@
import re
from typing import Tuple, List, Iterable, Iterator, Callable, Sequence, TypeVar, Optional
import os
from typing import Tuple, List, Iterable, Iterator, Callable, Sequence, TypeVar, Optional, Dict, Any, Union, Set, NoReturn
from enum import Enum
import contextlib
import textwrap
import hashlib
import functools
from tools.codegen.code_template import CodeTemplate
# Safely load fast C Yaml loader/dumper if they are available
try:
@ -94,3 +99,134 @@ def context(msg_fn: Callable[[], str]) -> Iterator[None]:
msg = f'{e.args[0]}\n{msg}' if e.args else msg
e.args = (msg,) + e.args[1:]
raise
# A little trick from https://github.com/python/mypy/issues/6366
# for getting mypy to do exhaustiveness checking
# TODO: put this somewhere else, maybe
def assert_never(x: NoReturn) -> NoReturn:
    """Always raise ``AssertionError`` naming the runtime type of ``x``.

    Intended for the final ``else`` of an exhaustive ``isinstance`` chain:
    the ``NoReturn`` parameter type lets mypy report any unhandled case.
    """
    unhandled = type(x).__name__
    raise AssertionError(f"Unhandled type: {unhandled}")
@functools.lru_cache(maxsize=None)
def _read_template(template_fn: str) -> CodeTemplate:
    """Load the code template stored at path ``template_fn``.

    Results are memoized with an unbounded ``lru_cache`` keyed on the path,
    so a path that has loaded successfully is not read again by later calls.
    """
    template = CodeTemplate.from_file(template_fn)
    return template
# String hash that's stable across different executions, unlike builtin hash
def string_stable_hash(s: str) -> int:
    """Return a deterministic integer hash of ``s``.

    The string is encoded as Latin-1, hashed with SHA-1, and the digest is
    read as a little-endian unsigned integer.  Encoding raises
    ``UnicodeEncodeError`` for characters outside Latin-1.
    """
    hasher = hashlib.sha1()
    hasher.update(s.encode('latin1'))
    return int.from_bytes(hasher.digest(), byteorder='little')
# A small abstraction for writing out generated files and keeping track
# of what files have been written (so you can write out a list of output
# files)
class FileManager:
    """Write generated files under ``install_dir`` and remember their names.

    Every output path is formed as ``"{install_dir}/{filename}"`` and added to
    ``filenames``.  When ``dry_run`` is true the names are still recorded but
    no environment is computed and no file is written.
    """
    install_dir: str     # prefix joined with '/' onto every output filename
    template_dir: str    # directory that template filenames are resolved against
    dry_run: bool        # when true, only record names; never touch the disk
    filenames: Set[str]  # full paths of every output registered so far

    def __init__(self, install_dir: str, template_dir: str, dry_run: bool) -> None:
        self.install_dir = install_dir
        self.template_dir = template_dir
        self.filenames = set()
        self.dry_run = dry_run

    def _write_if_changed(self, filename: str, contents: str) -> None:
        """Write ``contents`` to ``filename`` unless it already holds that text.

        A file that cannot be read (``IOError``) is treated as having no
        previous contents, so it is always (re)written.
        """
        old_contents: Optional[str]
        try:
            with open(filename, 'r') as f:
                old_contents = f.read()
        except IOError:
            old_contents = None
        if contents != old_contents:
            with open(filename, 'w') as f:
                f.write(contents)

    def write_with_template(self, filename: str, template_fn: str,
                            env_callable: Callable[[], Union[str, Dict[str, Any]]]) -> None:
        """Register ``filename`` as an output and, unless dry-running, write it.

        ``env_callable`` is only invoked when ``dry_run`` is false.  If it
        returns a dict, the template ``template_fn`` (looked up in
        ``template_dir``) is substituted with it; a ``generated_comment``
        entry is added to that dict when absent.  If it returns a string,
        the string is written verbatim.

        Raises ``AssertionError`` if the same output is registered twice.
        """
        filename = '{}/{}'.format(self.install_dir, filename)
        # The message must be an f-string, otherwise the literal text
        # "{filename}" is reported instead of the offending path.
        assert filename not in self.filenames, f"duplicate file write {filename}"
        self.filenames.add(filename)
        if not self.dry_run:
            env = env_callable()
            if isinstance(env, dict):
                # TODO: Update the comment reference to the correct location
                if 'generated_comment' not in env:
                    # NOTE(review): the marker is presumably split in two so this
                    # source file does not itself contain the joined marker - confirm.
                    comment = "@" + "generated by tools/codegen/gen.py"
                    comment += " from {}".format(os.path.basename(template_fn))
                    env['generated_comment'] = comment
                template = _read_template(os.path.join(self.template_dir, template_fn))
                self._write_if_changed(filename, template.substitute(env))
            elif isinstance(env, str):
                self._write_if_changed(filename, env)
            else:
                assert_never(env)

    def write(self, filename: str, env_callable: Callable[[], Union[str, Dict[str, Any]]]) -> None:
        """Write ``filename`` using the template that has the same name."""
        self.write_with_template(filename, filename, env_callable)

    def write_sharded(
            self,
            filename: str,
            items: Iterable[T],
            *,
            key_fn: Callable[[T], str],
            env_callable: Callable[[T], Dict[str, List[str]]],
            num_shards: int,
            base_env: Optional[Dict[str, Any]] = None,
            sharded_keys: Set[str]
    ) -> None:
        """Distribute ``items`` over ``num_shards`` files plus an "Everything" file.

        Each item is assigned to shard ``string_stable_hash(key_fn(item)) % num_shards``.
        The lists returned by ``env_callable(item)`` are appended to that
        shard's environment and to the "Everything" environment; every key
        returned must be listed in ``sharded_keys``.  ``base_env`` entries are
        copied into every environment first.  Outputs are named by inserting
        the shard id (``_0``, ``_1``, ..., ``Everything``) before the
        extension of ``filename``, which is also used as the template name.
        """
        everything: Dict[str, Any] = {'shard_id': 'Everything'}
        shards: List[Dict[str, Any]] = [{'shard_id': f'_{i}'} for i in range(num_shards)]
        all_shards = [everything] + shards

        if base_env is not None:
            for shard in all_shards:
                shard.update(base_env)

        for key in sharded_keys:
            for shard in all_shards:
                if key in shard:
                    assert isinstance(shard[key], list), "sharded keys in base_env must be a list"
                    # Copy so that appending to one shard does not leak into
                    # the list object shared through base_env.
                    shard[key] = shard[key].copy()
                else:
                    shard[key] = []

        def merge_env(into: Dict[str, List[str]], from_: Dict[str, List[str]]) -> None:
            for k, v in from_.items():
                assert k in sharded_keys, f"undeclared sharded key {k}"
                into[k] += v

        for item in items:
            key = key_fn(item)
            sid = string_stable_hash(key) % num_shards
            env = env_callable(item)

            merge_env(shards[sid], env)
            merge_env(everything, env)

        dot_pos = filename.rfind('.')
        if dot_pos == -1:
            dot_pos = len(filename)
        base_filename = filename[:dot_pos]
        extension = filename[dot_pos:]

        for shard in all_shards:
            shard_id = shard['shard_id']
            # The lambda is called (if at all) inside write_with_template before
            # the loop advances, so capturing the loop variable is safe here.
            self.write_with_template(f"{base_filename}{shard_id}{extension}",
                                     filename,
                                     lambda: shard)

        # filenames is used to track compiled files, but FooEverything.cpp isn't meant to be compiled
        self.filenames.discard(
            f"{self.install_dir}/{base_filename}Everything{extension}")

    def write_outputs(self, filename: str) -> None:
        """Write a file containing the list of all outputs which are
        generated by this script."""
        self._write_if_changed(
            filename,
            ''.join(name + ";" for name in sorted(self.filenames)))

View file

@ -6,7 +6,8 @@ import argparse
from tools.codegen.model import Variant
from tools.codegen.api.python import (PythonSignatureGroup,
PythonSignatureNativeFunctionPair)
from tools.codegen.gen import FileManager, parse_native_yaml
from tools.codegen.gen import parse_native_yaml
from tools.codegen.utils import FileManager
from typing import Sequence, List, Dict
from ..autograd.gen_python_functions import should_generate_py_binding, load_signatures, group_overloads