mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
move some codegen utilities into utils.py (#63094)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63094 This PR: - Moves `FileManager` and its dependencies (`assert_never` and other imports) to `utils.py`, and updates all of the call-sites with the fresh imports - Passes the list of NativeFunction objects into `gen_trace_type` directly, instead of requiring the function to regenerate it (we already have it) The purpose of the reshuffling is to avoid circular dependencies in the next PR, where I add codegen for the functionalization pass, which gets called from `gen.py` (but depends on some stuff from the autograd codegen - in partulcar, the list of view ops). Test Plan: Imported from OSS Reviewed By: albanD Differential Revision: D31942096 Pulled By: bdhirsh fbshipit-source-id: 36118facae61f25f8922bb43ad2818c80b53504e
This commit is contained in:
parent
b100a9ea82
commit
665c148e42
18 changed files with 171 additions and 166 deletions
|
|
@ -20,7 +20,8 @@ import textwrap
|
|||
|
||||
from typing import Dict, List, Any
|
||||
|
||||
from tools.codegen.gen import parse_native_yaml, FileManager
|
||||
from tools.codegen.gen import parse_native_yaml
|
||||
from tools.codegen.utils import FileManager
|
||||
from tools.codegen.context import with_native_function
|
||||
from tools.codegen.model import BaseOperatorName, NativeFunction
|
||||
import tools.codegen.api.python as python
|
||||
|
|
|
|||
|
|
@ -66,7 +66,7 @@ def gen_autograd(
|
|||
gen_inplace_or_view_type(out, native_functions_path, fns_with_diff_infos, template_path)
|
||||
|
||||
# operator filter not applied as tracing sources are excluded in selective build
|
||||
gen_trace_type(out, native_functions_path, template_path)
|
||||
gen_trace_type(out, native_funcs, template_path)
|
||||
# Generate Functions.h/cpp
|
||||
gen_autograd_functions_lib(
|
||||
out, differentiability_infos, template_path)
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from tools.codegen.api.types import (Binding, BaseCType, OptionalCType, tensorT,
|
|||
doubleT, scalarT, stringT, boolT, intArrayRefT,
|
||||
tensorListT, MutRefCType, ListCType, ArrayRefCType)
|
||||
from tools.codegen.code_template import CodeTemplate
|
||||
from tools.codegen.gen import FileManager
|
||||
from tools.codegen.utils import FileManager
|
||||
from tools.codegen.model import Argument
|
||||
|
||||
FUNCTION_DECLARATION = CodeTemplate("""\
|
||||
|
|
|
|||
|
|
@ -18,8 +18,7 @@ from tools.codegen.model import (
|
|||
SchemaKind, is_foreach_op,
|
||||
)
|
||||
from typing import List, Optional, Sequence, Tuple
|
||||
from tools.codegen.gen import FileManager
|
||||
from tools.codegen.utils import mapMaybe
|
||||
from tools.codegen.utils import mapMaybe, FileManager
|
||||
from .context import with_native_function_with_differentiability_info
|
||||
from .gen_trace_type import (
|
||||
MANUAL_AUTOGRAD, type_wrapper_name, tie_return_values, get_return_value
|
||||
|
|
|
|||
|
|
@ -52,11 +52,11 @@ from tools.codegen.api.python import (PythonArgument, PythonSignature,
|
|||
dispatch_lambda_return_str,
|
||||
has_tensor_options,
|
||||
namedtuple_fieldnames, signature)
|
||||
from tools.codegen.gen import cpp_string, parse_native_yaml, FileManager
|
||||
from tools.codegen.gen import cpp_string, parse_native_yaml
|
||||
from tools.codegen.context import with_native_function
|
||||
from tools.codegen.model import (Argument, BaseOperatorName, NativeFunction,
|
||||
Type, Variant)
|
||||
from tools.codegen.utils import split_name_params, YamlLoader
|
||||
from tools.codegen.utils import split_name_params, YamlLoader, FileManager
|
||||
|
||||
from typing import Dict, Optional, List, Tuple, Set, Sequence, Callable
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from tools.codegen.api.types import CppSignatureGroup, DispatcherSignature
|
|||
from tools.codegen.api import cpp
|
||||
from tools.codegen.code_template import CodeTemplate
|
||||
from tools.codegen.context import with_native_function
|
||||
from tools.codegen.gen import parse_native_yaml, FileManager
|
||||
from tools.codegen.utils import FileManager
|
||||
from tools.codegen.model import (Argument, NativeFunction, SchemaKind,
|
||||
TensorOptionsArguments)
|
||||
|
||||
|
|
@ -405,11 +405,10 @@ def gen_trace_type_func(
|
|||
'trace_wrapper_registrations': [method_registration(fn)],
|
||||
}
|
||||
|
||||
def gen_trace_type(out: str, native_yaml_path: str, template_path: str) -> None:
|
||||
def gen_trace_type(out: str, native_functions: List[NativeFunction], template_path: str) -> None:
|
||||
# NOTE: see Note [Sharded File] at the top of the VariableType.cpp
|
||||
# template regarding sharding of the generated files.
|
||||
fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
|
||||
native_functions = parse_native_yaml(native_yaml_path).native_functions
|
||||
fm.write_sharded(
|
||||
'TraceType.cpp',
|
||||
[fn for fn in native_functions if cpp.name(fn.func) not in MANUAL_TRACER],
|
||||
|
|
|
|||
|
|
@ -8,9 +8,9 @@ from typing import Optional, List
|
|||
from tools.codegen.api.types import CppSignatureGroup
|
||||
from tools.codegen.api import cpp
|
||||
import tools.codegen.api.python as python
|
||||
from tools.codegen.gen import parse_native_yaml, FileManager
|
||||
from tools.codegen.gen import parse_native_yaml
|
||||
from tools.codegen.context import with_native_function
|
||||
from tools.codegen.utils import mapMaybe
|
||||
from tools.codegen.utils import mapMaybe, FileManager
|
||||
from tools.codegen.model import NativeFunction, TensorOptionsArguments, Variant
|
||||
|
||||
OPTIONAL_TYPE_PATTERN = re.compile(r"c10::optional<(.+)>")
|
||||
|
|
|
|||
|
|
@ -47,8 +47,7 @@ from tools.codegen.api.autograd import (
|
|||
from tools.codegen.api import cpp
|
||||
from tools.codegen.code_template import CodeTemplate
|
||||
from tools.codegen.context import native_function_manager, with_native_function
|
||||
from tools.codegen.gen import FileManager
|
||||
from tools.codegen.utils import mapMaybe
|
||||
from tools.codegen.utils import mapMaybe, FileManager
|
||||
from tools.codegen.model import (Argument, NativeFunction, SchemaKind,
|
||||
SelfArgument, TensorOptionsArguments,
|
||||
BaseType, ListType)
|
||||
|
|
|
|||
|
|
@ -1,13 +1,14 @@
|
|||
from tools.codegen.model import (Argument, Arguments, BaseTy, BaseType,
|
||||
FunctionSchema, ListType, NativeFunction,
|
||||
OptionalType, Return, SelfArgument,
|
||||
TensorOptionsArguments, Type, assert_never)
|
||||
TensorOptionsArguments, Type)
|
||||
from tools.codegen.api.types import (ArgName, BaseCType, Binding, ConstRefCType, NamedCType, CType,
|
||||
MutRefCType, ArrayCType, ListCType, VectorCType, ArrayRefCType,
|
||||
OptionalCType, TupleCType, SpecialArgName, boolT, scalarT,
|
||||
tensorListT, dimnameListT, tensorT, voidT,
|
||||
BaseTypeToCppMapping, intArrayRefT, tensorOptionsT)
|
||||
from tools.codegen import local
|
||||
from tools.codegen.utils import assert_never
|
||||
from typing import Optional, Sequence, Union, List, Set
|
||||
|
||||
# This file describes the translation of JIT schema to the public C++
|
||||
|
|
|
|||
|
|
@ -1,10 +1,9 @@
|
|||
from tools.codegen.model import (Argument, FunctionSchema, Return,
|
||||
SelfArgument, TensorOptionsArguments, Type,
|
||||
assert_never)
|
||||
SelfArgument, TensorOptionsArguments, Type)
|
||||
|
||||
from tools.codegen.api.types import ArgName, Binding, NamedCType, CType
|
||||
from tools.codegen.api import cpp
|
||||
from tools.codegen.utils import concatMap
|
||||
from tools.codegen.utils import concatMap, assert_never
|
||||
|
||||
import itertools
|
||||
from typing import Sequence, List, Union
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
from tools.codegen.model import (Argument, FunctionSchema, Return,
|
||||
SelfArgument, TensorOptionsArguments, Type,
|
||||
assert_never)
|
||||
SelfArgument, TensorOptionsArguments, Type)
|
||||
|
||||
from tools.codegen.api.types import (ArgName, BaseCType, Binding,
|
||||
ConstRefCType, NamedCType, CType, MutRefCType, ListCType,
|
||||
|
|
@ -8,6 +7,7 @@ from tools.codegen.api.types import (ArgName, BaseCType, Binding,
|
|||
deviceT, boolT, scalarTypeT)
|
||||
from tools.codegen.api import cpp
|
||||
from tools.codegen import local
|
||||
from tools.codegen.utils import assert_never
|
||||
|
||||
from typing import Union, Sequence, List, Optional
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
from tools.codegen.model import (Argument, BaseTy, BaseType, ListType,
|
||||
NativeFunctionsGroup, OptionalType,
|
||||
SelfArgument, TensorOptionsArguments, Type,
|
||||
assert_never)
|
||||
SelfArgument, TensorOptionsArguments, Type)
|
||||
|
||||
from tools.codegen.api.types import (ArgName, BaseCType, Binding, ArrayRefCType,
|
||||
ConstRefCType, OptionalCType, NamedCType,
|
||||
|
|
@ -9,6 +8,7 @@ from tools.codegen.api.types import (ArgName, BaseCType, Binding, ArrayRefCType,
|
|||
optionalTensorRefT, optionalScalarRefT)
|
||||
|
||||
from tools.codegen.api import cpp
|
||||
from tools.codegen.utils import assert_never
|
||||
|
||||
from typing import Union, List
|
||||
|
||||
|
|
|
|||
|
|
@ -5,11 +5,11 @@ from dataclasses import dataclass
|
|||
import textwrap
|
||||
|
||||
from tools.codegen.context import method_with_native_function, native_function_manager
|
||||
from tools.codegen.utils import Target, mapMaybe
|
||||
from tools.codegen.utils import Target, mapMaybe, assert_never
|
||||
from tools.codegen.model import (DispatchKey, NativeFunction,
|
||||
NativeFunctionsGroup, SchemaKind,
|
||||
TensorOptionsArguments,
|
||||
DeviceCheckType, Argument, assert_never,
|
||||
DeviceCheckType, Argument,
|
||||
is_cuda_dispatch_key, BackendIndex,
|
||||
gets_generated_out_inplace_wrapper)
|
||||
from tools.codegen.api.types import (BaseCType, Binding, ConstRefCType,
|
||||
|
|
|
|||
|
|
@ -1,23 +1,20 @@
|
|||
import os
|
||||
from typing import List, Dict, Optional, Tuple, Set, Callable, Any, Union, Sequence, TypeVar, Iterable
|
||||
from typing import List, Dict, Optional, Tuple, Set, Any, Union, Sequence, TypeVar
|
||||
from typing_extensions import Literal
|
||||
import yaml
|
||||
from collections import OrderedDict, defaultdict, namedtuple
|
||||
import argparse
|
||||
import pathlib
|
||||
import functools
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
import hashlib
|
||||
|
||||
from tools.codegen.code_template import CodeTemplate
|
||||
from tools.codegen.model import (Argument, DispatchKey, FunctionSchema,
|
||||
Location, NativeFunction,
|
||||
NativeFunctionsGroup, OperatorName,
|
||||
BackendIndex, BackendMetadata,
|
||||
OptionalType, SchemaKind, SelfArgument,
|
||||
TensorOptionsArguments, Type, Variant,
|
||||
assert_never, is_cuda_dispatch_key,
|
||||
is_cuda_dispatch_key,
|
||||
is_generic_dispatch_key)
|
||||
from tools.codegen.api.types import (Binding, CppSignature, CppSignatureGroup,
|
||||
DispatcherSignature, NativeSignature)
|
||||
|
|
@ -28,7 +25,9 @@ import tools.codegen.api.meta as meta
|
|||
import tools.codegen.api.structured as structured
|
||||
from tools.codegen.api.translate import translate
|
||||
from tools.codegen.selective_build.selector import SelectiveBuilder
|
||||
from tools.codegen.utils import Target, concatMap, context, mapMaybe, YamlDumper, YamlLoader
|
||||
from tools.codegen.utils import (
|
||||
Target, concatMap, context, mapMaybe, YamlDumper, YamlLoader, FileManager, assert_never
|
||||
)
|
||||
from tools.codegen.context import (method_with_native_function,
|
||||
native_function_manager,
|
||||
with_native_function_and_indices,
|
||||
|
|
@ -884,131 +883,6 @@ def compute_registration_declarations(f: NativeFunction, backend_indices: Dict[D
|
|||
#
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def _read_template(template_fn: str) -> CodeTemplate:
|
||||
return CodeTemplate.from_file(template_fn)
|
||||
|
||||
|
||||
# String hash that's stable across different executions, unlike builtin hash
|
||||
def string_stable_hash(s: str) -> int:
|
||||
sha1 = hashlib.sha1(s.encode('latin1')).digest()
|
||||
return int.from_bytes(sha1, byteorder='little')
|
||||
|
||||
# A small abstraction for writing out generated files and keeping track
|
||||
# of what files have been written (so you can write out a list of output
|
||||
# files)
|
||||
class FileManager:
|
||||
install_dir: str
|
||||
template_dir: str
|
||||
dry_run: bool
|
||||
filenames: Set[str]
|
||||
|
||||
def __init__(self, install_dir: str, template_dir: str, dry_run: bool) -> None:
|
||||
self.install_dir = install_dir
|
||||
self.template_dir = template_dir
|
||||
self.filenames = set()
|
||||
self.dry_run = dry_run
|
||||
|
||||
def _write_if_changed(self, filename: str, contents: str) -> None:
|
||||
old_contents: Optional[str]
|
||||
try:
|
||||
with open(filename, 'r') as f:
|
||||
old_contents = f.read()
|
||||
except IOError:
|
||||
old_contents = None
|
||||
if contents != old_contents:
|
||||
with open(filename, 'w') as f:
|
||||
f.write(contents)
|
||||
|
||||
def write_with_template(self, filename: str, template_fn: str,
|
||||
env_callable: Callable[[], Union[str, Dict[str, Any]]]) -> None:
|
||||
filename = '{}/{}'.format(self.install_dir, filename)
|
||||
assert filename not in self.filenames, "duplicate file write {filename}"
|
||||
self.filenames.add(filename)
|
||||
if not self.dry_run:
|
||||
env = env_callable()
|
||||
if isinstance(env, dict):
|
||||
# TODO: Update the comment reference to the correct location
|
||||
if 'generated_comment' not in env:
|
||||
comment = "@" + "generated by tools/codegen/gen.py"
|
||||
comment += " from {}".format(os.path.basename(template_fn))
|
||||
env['generated_comment'] = comment
|
||||
template = _read_template(os.path.join(self.template_dir, template_fn))
|
||||
self._write_if_changed(filename, template.substitute(env))
|
||||
elif isinstance(env, str):
|
||||
self._write_if_changed(filename, env)
|
||||
else:
|
||||
assert_never(env)
|
||||
|
||||
|
||||
def write(self, filename: str, env_callable: Callable[[], Union[str, Union[str, Dict[str, Any]]]]) -> None:
|
||||
self.write_with_template(filename, filename, env_callable)
|
||||
|
||||
def write_sharded(
|
||||
self,
|
||||
filename: str,
|
||||
items: Iterable[T],
|
||||
*,
|
||||
key_fn: Callable[[T], str],
|
||||
env_callable: Callable[[T], Dict[str, List[str]]],
|
||||
num_shards: int,
|
||||
base_env: Optional[Dict[str, Any]] = None,
|
||||
sharded_keys: Set[str]
|
||||
) -> None:
|
||||
|
||||
everything: Dict[str, Any] = {'shard_id': 'Everything'}
|
||||
shards: List[Dict[str, Any]] = [{'shard_id': f'_{i}'} for i in range(num_shards)]
|
||||
all_shards = [everything] + shards
|
||||
|
||||
if base_env is not None:
|
||||
for shard in all_shards:
|
||||
shard.update(base_env)
|
||||
|
||||
for key in sharded_keys:
|
||||
for shard in all_shards:
|
||||
if key in shard:
|
||||
assert isinstance(shard[key], list), "sharded keys in base_env must be a list"
|
||||
shard[key] = shard[key].copy()
|
||||
else:
|
||||
shard[key] = []
|
||||
|
||||
|
||||
def merge_env(into: Dict[str, List[str]], from_: Dict[str, List[str]]) -> None:
|
||||
for k, v in from_.items():
|
||||
assert k in sharded_keys, f"undeclared sharded key {k}"
|
||||
into[k] += v
|
||||
|
||||
for item in items:
|
||||
key = key_fn(item)
|
||||
sid = string_stable_hash(key) % num_shards
|
||||
env = env_callable(item)
|
||||
|
||||
merge_env(shards[sid], env)
|
||||
merge_env(everything, env)
|
||||
|
||||
dot_pos = filename.rfind('.')
|
||||
if dot_pos == -1:
|
||||
dot_pos = len(filename)
|
||||
base_filename = filename[:dot_pos]
|
||||
extension = filename[dot_pos:]
|
||||
|
||||
for shard in all_shards:
|
||||
shard_id = shard['shard_id']
|
||||
self.write_with_template(f"{base_filename}{shard_id}{extension}",
|
||||
filename,
|
||||
lambda: shard)
|
||||
|
||||
# filenames is used to track compiled files, but FooEverything.cpp isn't meant to be compiled
|
||||
self.filenames.discard(
|
||||
f"{self.install_dir}/{base_filename}Everything{extension}")
|
||||
|
||||
def write_outputs(self, filename: str) -> None:
|
||||
"""Write a file containing the list of all outputs which are
|
||||
generated by this script."""
|
||||
self._write_if_changed(
|
||||
filename,
|
||||
''.join(name + ";" for name in sorted(self.filenames)))
|
||||
|
||||
def get_custom_build_selector(
|
||||
provided_op_registration_allowlist: Optional[List[str]],
|
||||
op_selection_yaml_path: Optional[str]) -> SelectiveBuilder:
|
||||
|
|
|
|||
|
|
@ -5,11 +5,11 @@ import yaml
|
|||
import re
|
||||
from collections import namedtuple, Counter, defaultdict
|
||||
from typing import List, Dict, Union, Sequence, Optional
|
||||
from tools.codegen.gen import FileManager, get_grouped_native_functions, parse_native_yaml
|
||||
from tools.codegen.gen import get_grouped_native_functions, parse_native_yaml
|
||||
from tools.codegen.model import (BackendIndex, BackendMetadata, DispatchKey,
|
||||
NativeFunction, NativeFunctionsGroup, OperatorName)
|
||||
from tools.codegen.selective_build.selector import SelectiveBuilder
|
||||
from tools.codegen.utils import Target, concatMap, context, YamlLoader
|
||||
from tools.codegen.utils import Target, concatMap, context, YamlLoader, FileManager
|
||||
from tools.codegen.context import native_function_manager
|
||||
import tools.codegen.dest as dest
|
||||
import tools.codegen.api.dispatcher as dispatcher
|
||||
|
|
|
|||
|
|
@ -1,16 +1,12 @@
|
|||
import re
|
||||
|
||||
from tools.codegen.utils import assert_never
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Dict, Optional, Iterator, Tuple, Set, NoReturn, Sequence, Callable, Union
|
||||
from typing import List, Dict, Optional, Iterator, Tuple, Set, Sequence, Callable, Union
|
||||
from enum import Enum, auto
|
||||
import itertools
|
||||
|
||||
# A little trick from https://github.com/python/mypy/issues/6366
|
||||
# for getting mypy to do exhaustiveness checking
|
||||
# TODO: put this somewhere else, maybe
|
||||
def assert_never(x: NoReturn) -> NoReturn:
|
||||
raise AssertionError("Unhandled type: {}".format(type(x).__name__))
|
||||
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
|
||||
#
|
||||
# DATA MODEL
|
||||
|
|
|
|||
|
|
@ -1,8 +1,13 @@
|
|||
import re
|
||||
from typing import Tuple, List, Iterable, Iterator, Callable, Sequence, TypeVar, Optional
|
||||
import os
|
||||
from typing import Tuple, List, Iterable, Iterator, Callable, Sequence, TypeVar, Optional, Dict, Any, Union, Set, NoReturn
|
||||
from enum import Enum
|
||||
import contextlib
|
||||
import textwrap
|
||||
import hashlib
|
||||
import functools
|
||||
|
||||
from tools.codegen.code_template import CodeTemplate
|
||||
|
||||
# Safely load fast C Yaml loader/dumper if they are available
|
||||
try:
|
||||
|
|
@ -94,3 +99,134 @@ def context(msg_fn: Callable[[], str]) -> Iterator[None]:
|
|||
msg = f'{e.args[0]}\n{msg}' if e.args else msg
|
||||
e.args = (msg,) + e.args[1:]
|
||||
raise
|
||||
|
||||
# A little trick from https://github.com/python/mypy/issues/6366
|
||||
# for getting mypy to do exhaustiveness checking
|
||||
# TODO: put this somewhere else, maybe
|
||||
def assert_never(x: NoReturn) -> NoReturn:
|
||||
raise AssertionError("Unhandled type: {}".format(type(x).__name__))
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def _read_template(template_fn: str) -> CodeTemplate:
|
||||
return CodeTemplate.from_file(template_fn)
|
||||
|
||||
|
||||
# String hash that's stable across different executions, unlike builtin hash
|
||||
def string_stable_hash(s: str) -> int:
|
||||
sha1 = hashlib.sha1(s.encode('latin1')).digest()
|
||||
return int.from_bytes(sha1, byteorder='little')
|
||||
|
||||
# A small abstraction for writing out generated files and keeping track
|
||||
# of what files have been written (so you can write out a list of output
|
||||
# files)
|
||||
class FileManager:
|
||||
install_dir: str
|
||||
template_dir: str
|
||||
dry_run: bool
|
||||
filenames: Set[str]
|
||||
|
||||
def __init__(self, install_dir: str, template_dir: str, dry_run: bool) -> None:
|
||||
self.install_dir = install_dir
|
||||
self.template_dir = template_dir
|
||||
self.filenames = set()
|
||||
self.dry_run = dry_run
|
||||
|
||||
def _write_if_changed(self, filename: str, contents: str) -> None:
|
||||
old_contents: Optional[str]
|
||||
try:
|
||||
with open(filename, 'r') as f:
|
||||
old_contents = f.read()
|
||||
except IOError:
|
||||
old_contents = None
|
||||
if contents != old_contents:
|
||||
with open(filename, 'w') as f:
|
||||
f.write(contents)
|
||||
|
||||
def write_with_template(self, filename: str, template_fn: str,
|
||||
env_callable: Callable[[], Union[str, Dict[str, Any]]]) -> None:
|
||||
filename = '{}/{}'.format(self.install_dir, filename)
|
||||
assert filename not in self.filenames, "duplicate file write {filename}"
|
||||
self.filenames.add(filename)
|
||||
if not self.dry_run:
|
||||
env = env_callable()
|
||||
if isinstance(env, dict):
|
||||
# TODO: Update the comment reference to the correct location
|
||||
if 'generated_comment' not in env:
|
||||
comment = "@" + "generated by tools/codegen/gen.py"
|
||||
comment += " from {}".format(os.path.basename(template_fn))
|
||||
env['generated_comment'] = comment
|
||||
template = _read_template(os.path.join(self.template_dir, template_fn))
|
||||
self._write_if_changed(filename, template.substitute(env))
|
||||
elif isinstance(env, str):
|
||||
self._write_if_changed(filename, env)
|
||||
else:
|
||||
assert_never(env)
|
||||
|
||||
|
||||
def write(self, filename: str, env_callable: Callable[[], Union[str, Union[str, Dict[str, Any]]]]) -> None:
|
||||
self.write_with_template(filename, filename, env_callable)
|
||||
|
||||
def write_sharded(
|
||||
self,
|
||||
filename: str,
|
||||
items: Iterable[T],
|
||||
*,
|
||||
key_fn: Callable[[T], str],
|
||||
env_callable: Callable[[T], Dict[str, List[str]]],
|
||||
num_shards: int,
|
||||
base_env: Optional[Dict[str, Any]] = None,
|
||||
sharded_keys: Set[str]
|
||||
) -> None:
|
||||
|
||||
everything: Dict[str, Any] = {'shard_id': 'Everything'}
|
||||
shards: List[Dict[str, Any]] = [{'shard_id': f'_{i}'} for i in range(num_shards)]
|
||||
all_shards = [everything] + shards
|
||||
|
||||
if base_env is not None:
|
||||
for shard in all_shards:
|
||||
shard.update(base_env)
|
||||
|
||||
for key in sharded_keys:
|
||||
for shard in all_shards:
|
||||
if key in shard:
|
||||
assert isinstance(shard[key], list), "sharded keys in base_env must be a list"
|
||||
shard[key] = shard[key].copy()
|
||||
else:
|
||||
shard[key] = []
|
||||
|
||||
|
||||
def merge_env(into: Dict[str, List[str]], from_: Dict[str, List[str]]) -> None:
|
||||
for k, v in from_.items():
|
||||
assert k in sharded_keys, f"undeclared sharded key {k}"
|
||||
into[k] += v
|
||||
|
||||
for item in items:
|
||||
key = key_fn(item)
|
||||
sid = string_stable_hash(key) % num_shards
|
||||
env = env_callable(item)
|
||||
|
||||
merge_env(shards[sid], env)
|
||||
merge_env(everything, env)
|
||||
|
||||
dot_pos = filename.rfind('.')
|
||||
if dot_pos == -1:
|
||||
dot_pos = len(filename)
|
||||
base_filename = filename[:dot_pos]
|
||||
extension = filename[dot_pos:]
|
||||
|
||||
for shard in all_shards:
|
||||
shard_id = shard['shard_id']
|
||||
self.write_with_template(f"{base_filename}{shard_id}{extension}",
|
||||
filename,
|
||||
lambda: shard)
|
||||
|
||||
# filenames is used to track compiled files, but FooEverything.cpp isn't meant to be compiled
|
||||
self.filenames.discard(
|
||||
f"{self.install_dir}/{base_filename}Everything{extension}")
|
||||
|
||||
def write_outputs(self, filename: str) -> None:
|
||||
"""Write a file containing the list of all outputs which are
|
||||
generated by this script."""
|
||||
self._write_if_changed(
|
||||
filename,
|
||||
''.join(name + ";" for name in sorted(self.filenames)))
|
||||
|
|
|
|||
|
|
@ -6,7 +6,8 @@ import argparse
|
|||
from tools.codegen.model import Variant
|
||||
from tools.codegen.api.python import (PythonSignatureGroup,
|
||||
PythonSignatureNativeFunctionPair)
|
||||
from tools.codegen.gen import FileManager, parse_native_yaml
|
||||
from tools.codegen.gen import parse_native_yaml
|
||||
from tools.codegen.utils import FileManager
|
||||
from typing import Sequence, List, Dict
|
||||
|
||||
from ..autograd.gen_python_functions import should_generate_py_binding, load_signatures, group_overloads
|
||||
|
|
|
|||
Loading…
Reference in a new issue