From 665c148e423bf36a06b45a1b52b3ec68403d665c Mon Sep 17 00:00:00 2001 From: Brian Hirsh Date: Thu, 28 Oct 2021 10:43:11 -0700 Subject: [PATCH] move some codegen utilities into utils.py (#63094) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63094 This PR: - Moves `FileManager` and its dependencies (`assert_never` and other imports) to `utils.py`, and updates all of the call-sites with the fresh imports - Passes the list of NativeFunction objects into `gen_trace_type` directly, instead of requiring the function to regenerate it (we already have it) The purpose of the reshuffling is to avoid circular dependencies in the next PR, where I add codegen for the functionalization pass, which gets called from `gen.py` (but depends on some stuff from the autograd codegen - in partulcar, the list of view ops). Test Plan: Imported from OSS Reviewed By: albanD Differential Revision: D31942096 Pulled By: bdhirsh fbshipit-source-id: 36118facae61f25f8922bb43ad2818c80b53504e --- tools/autograd/gen_annotated_fn_args.py | 3 +- tools/autograd/gen_autograd.py | 2 +- tools/autograd/gen_autograd_functions.py | 2 +- tools/autograd/gen_inplace_or_view_type.py | 3 +- tools/autograd/gen_python_functions.py | 4 +- tools/autograd/gen_trace_type.py | 5 +- tools/autograd/gen_variable_factories.py | 4 +- tools/autograd/gen_variable_type.py | 3 +- tools/codegen/api/cpp.py | 3 +- tools/codegen/api/dispatcher.py | 5 +- tools/codegen/api/native.py | 4 +- tools/codegen/api/structured.py | 4 +- tools/codegen/dest/register_dispatch_key.py | 4 +- tools/codegen/gen.py | 136 +------------------ tools/codegen/gen_backend_stubs.py | 4 +- tools/codegen/model.py | 10 +- tools/codegen/utils.py | 138 +++++++++++++++++++- tools/pyi/gen_pyi.py | 3 +- 18 files changed, 171 insertions(+), 166 deletions(-) diff --git a/tools/autograd/gen_annotated_fn_args.py b/tools/autograd/gen_annotated_fn_args.py index 44f7c14438a..2d1dbd5c71a 100644 --- a/tools/autograd/gen_annotated_fn_args.py +++ b/tools/autograd/gen_annotated_fn_args.py @@ -20,7 +20,8 @@ import textwrap from typing import Dict, List, Any -from tools.codegen.gen import parse_native_yaml, FileManager +from tools.codegen.gen import parse_native_yaml +from tools.codegen.utils import FileManager from tools.codegen.context import with_native_function from tools.codegen.model import BaseOperatorName, NativeFunction import tools.codegen.api.python as python diff --git a/tools/autograd/gen_autograd.py b/tools/autograd/gen_autograd.py index 5f0d7eca00e..5c70be3df59 100644 --- a/tools/autograd/gen_autograd.py +++ b/tools/autograd/gen_autograd.py @@ -66,7 +66,7 @@ def gen_autograd( gen_inplace_or_view_type(out, native_functions_path, fns_with_diff_infos, template_path) # operator filter not applied as tracing sources are excluded in selective build - gen_trace_type(out, native_functions_path, template_path) + gen_trace_type(out, native_funcs, template_path) # Generate Functions.h/cpp gen_autograd_functions_lib( out, differentiability_infos, template_path) diff --git a/tools/autograd/gen_autograd_functions.py b/tools/autograd/gen_autograd_functions.py index 0e79eb16a65..be7c7212db8 100644 --- a/tools/autograd/gen_autograd_functions.py +++ b/tools/autograd/gen_autograd_functions.py @@ -15,7 +15,7 @@ from tools.codegen.api.types import (Binding, BaseCType, OptionalCType, tensorT, doubleT, scalarT, stringT, boolT, intArrayRefT, tensorListT, MutRefCType, ListCType, ArrayRefCType) from tools.codegen.code_template import CodeTemplate -from tools.codegen.gen import FileManager +from tools.codegen.utils import FileManager from tools.codegen.model import Argument FUNCTION_DECLARATION = CodeTemplate("""\ diff --git a/tools/autograd/gen_inplace_or_view_type.py b/tools/autograd/gen_inplace_or_view_type.py index f49f79a5ad1..09ac645bc05 100644 --- a/tools/autograd/gen_inplace_or_view_type.py +++ b/tools/autograd/gen_inplace_or_view_type.py @@ -18,8 +18,7 @@ from tools.codegen.model import ( SchemaKind, is_foreach_op, ) from typing import List, Optional, Sequence, Tuple -from tools.codegen.gen import FileManager -from tools.codegen.utils import mapMaybe +from tools.codegen.utils import mapMaybe, FileManager from .context import with_native_function_with_differentiability_info from .gen_trace_type import ( MANUAL_AUTOGRAD, type_wrapper_name, tie_return_values, get_return_value diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py index 24cd9da2284..917a8a9f80c 100644 --- a/tools/autograd/gen_python_functions.py +++ b/tools/autograd/gen_python_functions.py @@ -52,11 +52,11 @@ from tools.codegen.api.python import (PythonArgument, PythonSignature, dispatch_lambda_return_str, has_tensor_options, namedtuple_fieldnames, signature) -from tools.codegen.gen import cpp_string, parse_native_yaml, FileManager +from tools.codegen.gen import cpp_string, parse_native_yaml from tools.codegen.context import with_native_function from tools.codegen.model import (Argument, BaseOperatorName, NativeFunction, Type, Variant) -from tools.codegen.utils import split_name_params, YamlLoader +from tools.codegen.utils import split_name_params, YamlLoader, FileManager from typing import Dict, Optional, List, Tuple, Set, Sequence, Callable diff --git a/tools/autograd/gen_trace_type.py b/tools/autograd/gen_trace_type.py index 8834c486f26..4183d3c494d 100644 --- a/tools/autograd/gen_trace_type.py +++ b/tools/autograd/gen_trace_type.py @@ -5,7 +5,7 @@ from tools.codegen.api.types import CppSignatureGroup, DispatcherSignature from tools.codegen.api import cpp from tools.codegen.code_template import CodeTemplate from tools.codegen.context import with_native_function -from tools.codegen.gen import parse_native_yaml, FileManager +from tools.codegen.utils import FileManager from tools.codegen.model import (Argument, NativeFunction, SchemaKind, TensorOptionsArguments) @@ -405,11 +405,10 @@ def gen_trace_type_func( 'trace_wrapper_registrations': [method_registration(fn)], } -def gen_trace_type(out: str, native_yaml_path: str, template_path: str) -> None: +def gen_trace_type(out: str, native_functions: List[NativeFunction], template_path: str) -> None: # NOTE: see Note [Sharded File] at the top of the VariableType.cpp # template regarding sharding of the generated files. fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False) - native_functions = parse_native_yaml(native_yaml_path).native_functions fm.write_sharded( 'TraceType.cpp', [fn for fn in native_functions if cpp.name(fn.func) not in MANUAL_TRACER], diff --git a/tools/autograd/gen_variable_factories.py b/tools/autograd/gen_variable_factories.py index 5f5eef25178..ebaa7a29c8b 100644 --- a/tools/autograd/gen_variable_factories.py +++ b/tools/autograd/gen_variable_factories.py @@ -8,9 +8,9 @@ from typing import Optional, List from tools.codegen.api.types import CppSignatureGroup from tools.codegen.api import cpp import tools.codegen.api.python as python -from tools.codegen.gen import parse_native_yaml, FileManager +from tools.codegen.gen import parse_native_yaml from tools.codegen.context import with_native_function -from tools.codegen.utils import mapMaybe +from tools.codegen.utils import mapMaybe, FileManager from tools.codegen.model import NativeFunction, TensorOptionsArguments, Variant OPTIONAL_TYPE_PATTERN = re.compile(r"c10::optional<(.+)>") diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index 36480b26411..8fdb1870c0d 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -47,8 +47,7 @@ from tools.codegen.api.autograd import ( from tools.codegen.api import cpp from tools.codegen.code_template import CodeTemplate from tools.codegen.context import native_function_manager, with_native_function -from tools.codegen.gen import FileManager -from tools.codegen.utils import mapMaybe +from tools.codegen.utils import mapMaybe, FileManager from tools.codegen.model import (Argument, NativeFunction, SchemaKind, SelfArgument, TensorOptionsArguments, BaseType, ListType) diff --git a/tools/codegen/api/cpp.py b/tools/codegen/api/cpp.py index 54ae4670cf0..8e1427f479f 100644 --- a/tools/codegen/api/cpp.py +++ b/tools/codegen/api/cpp.py @@ -1,13 +1,14 @@ from tools.codegen.model import (Argument, Arguments, BaseTy, BaseType, FunctionSchema, ListType, NativeFunction, OptionalType, Return, SelfArgument, - TensorOptionsArguments, Type, assert_never) + TensorOptionsArguments, Type) from tools.codegen.api.types import (ArgName, BaseCType, Binding, ConstRefCType, NamedCType, CType, MutRefCType, ArrayCType, ListCType, VectorCType, ArrayRefCType, OptionalCType, TupleCType, SpecialArgName, boolT, scalarT, tensorListT, dimnameListT, tensorT, voidT, BaseTypeToCppMapping, intArrayRefT, tensorOptionsT) from tools.codegen import local +from tools.codegen.utils import assert_never from typing import Optional, Sequence, Union, List, Set # This file describes the translation of JIT schema to the public C++ diff --git a/tools/codegen/api/dispatcher.py b/tools/codegen/api/dispatcher.py index be51c4adbc4..a870114596d 100644 --- a/tools/codegen/api/dispatcher.py +++ b/tools/codegen/api/dispatcher.py @@ -1,10 +1,9 @@ from tools.codegen.model import (Argument, FunctionSchema, Return, - SelfArgument, TensorOptionsArguments, Type, - assert_never) + SelfArgument, TensorOptionsArguments, Type) from tools.codegen.api.types import ArgName, Binding, NamedCType, CType from tools.codegen.api import cpp -from tools.codegen.utils import concatMap +from tools.codegen.utils import concatMap, assert_never import itertools from typing import Sequence, List, Union diff --git a/tools/codegen/api/native.py b/tools/codegen/api/native.py index d2bd01a8e5c..d072f20d427 100644 --- a/tools/codegen/api/native.py +++ b/tools/codegen/api/native.py @@ -1,6 +1,5 @@ from tools.codegen.model import (Argument, FunctionSchema, Return, - SelfArgument, TensorOptionsArguments, Type, - assert_never) + SelfArgument, TensorOptionsArguments, Type) from tools.codegen.api.types import (ArgName, BaseCType, Binding, ConstRefCType, NamedCType, CType, MutRefCType, ListCType, @@ -8,6 +7,7 @@ from tools.codegen.api.types import (ArgName, BaseCType, Binding, deviceT, boolT, scalarTypeT) from tools.codegen.api import cpp from tools.codegen import local +from tools.codegen.utils import assert_never from typing import Union, Sequence, List, Optional diff --git a/tools/codegen/api/structured.py b/tools/codegen/api/structured.py index 6aab794413c..664c6ad01f6 100644 --- a/tools/codegen/api/structured.py +++ b/tools/codegen/api/structured.py @@ -1,7 +1,6 @@ from tools.codegen.model import (Argument, BaseTy, BaseType, ListType, NativeFunctionsGroup, OptionalType, - SelfArgument, TensorOptionsArguments, Type, - assert_never) + SelfArgument, TensorOptionsArguments, Type) from tools.codegen.api.types import (ArgName, BaseCType, Binding, ArrayRefCType, ConstRefCType, OptionalCType, NamedCType, @@ -9,6 +8,7 @@ from tools.codegen.api.types import (ArgName, BaseCType, Binding, ArrayRefCType, optionalTensorRefT, optionalScalarRefT) from tools.codegen.api import cpp +from tools.codegen.utils import assert_never from typing import Union, List diff --git a/tools/codegen/dest/register_dispatch_key.py b/tools/codegen/dest/register_dispatch_key.py index 7b828dc4657..406496a3b82 100644 --- a/tools/codegen/dest/register_dispatch_key.py +++ b/tools/codegen/dest/register_dispatch_key.py @@ -5,11 +5,11 @@ from dataclasses import dataclass import textwrap from tools.codegen.context import method_with_native_function, native_function_manager -from tools.codegen.utils import Target, mapMaybe +from tools.codegen.utils import Target, mapMaybe, assert_never from tools.codegen.model import (DispatchKey, NativeFunction, NativeFunctionsGroup, SchemaKind, TensorOptionsArguments, - DeviceCheckType, Argument, assert_never, + DeviceCheckType, Argument, is_cuda_dispatch_key, BackendIndex, gets_generated_out_inplace_wrapper) from tools.codegen.api.types import (BaseCType, Binding, ConstRefCType, diff --git a/tools/codegen/gen.py b/tools/codegen/gen.py index c986f831160..0e48a9b921f 100644 --- a/tools/codegen/gen.py +++ b/tools/codegen/gen.py @@ -1,23 +1,20 @@ import os -from typing import List, Dict, Optional, Tuple, Set, Callable, Any, Union, Sequence, TypeVar, Iterable +from typing import List, Dict, Optional, Tuple, Set, Any, Union, Sequence, TypeVar from typing_extensions import Literal import yaml from collections import OrderedDict, defaultdict, namedtuple import argparse import pathlib -import functools import json from dataclasses import dataclass -import hashlib -from tools.codegen.code_template import CodeTemplate from tools.codegen.model import (Argument, DispatchKey, FunctionSchema, Location, NativeFunction, NativeFunctionsGroup, OperatorName, BackendIndex, BackendMetadata, OptionalType, SchemaKind, SelfArgument, TensorOptionsArguments, Type, Variant, - assert_never, is_cuda_dispatch_key, + is_cuda_dispatch_key, is_generic_dispatch_key) from tools.codegen.api.types import (Binding, CppSignature, CppSignatureGroup, DispatcherSignature, NativeSignature) @@ -28,7 +25,9 @@ import tools.codegen.api.meta as meta import tools.codegen.api.structured as structured from tools.codegen.api.translate import translate from tools.codegen.selective_build.selector import SelectiveBuilder -from tools.codegen.utils import Target, concatMap, context, mapMaybe, YamlDumper, YamlLoader +from tools.codegen.utils import ( + Target, concatMap, context, mapMaybe, YamlDumper, YamlLoader, FileManager, assert_never +) from tools.codegen.context import (method_with_native_function, native_function_manager, with_native_function_and_indices, @@ -884,131 +883,6 @@ def compute_registration_declarations(f: NativeFunction, backend_indices: Dict[D # # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # -@functools.lru_cache(maxsize=None) -def _read_template(template_fn: str) -> CodeTemplate: - return CodeTemplate.from_file(template_fn) - - -# String hash that's stable across different executions, unlike builtin hash -def string_stable_hash(s: str) -> int: - sha1 = hashlib.sha1(s.encode('latin1')).digest() - return int.from_bytes(sha1, byteorder='little') - -# A small abstraction for writing out generated files and keeping track -# of what files have been written (so you can write out a list of output -# files) -class FileManager: - install_dir: str - template_dir: str - dry_run: bool - filenames: Set[str] - - def __init__(self, install_dir: str, template_dir: str, dry_run: bool) -> None: - self.install_dir = install_dir - self.template_dir = template_dir - self.filenames = set() - self.dry_run = dry_run - - def _write_if_changed(self, filename: str, contents: str) -> None: - old_contents: Optional[str] - try: - with open(filename, 'r') as f: - old_contents = f.read() - except IOError: - old_contents = None - if contents != old_contents: - with open(filename, 'w') as f: - f.write(contents) - - def write_with_template(self, filename: str, template_fn: str, - env_callable: Callable[[], Union[str, Dict[str, Any]]]) -> None: - filename = '{}/{}'.format(self.install_dir, filename) - assert filename not in self.filenames, "duplicate file write {filename}" - self.filenames.add(filename) - if not self.dry_run: - env = env_callable() - if isinstance(env, dict): - # TODO: Update the comment reference to the correct location - if 'generated_comment' not in env: - comment = "@" + "generated by tools/codegen/gen.py" - comment += " from {}".format(os.path.basename(template_fn)) - env['generated_comment'] = comment - template = _read_template(os.path.join(self.template_dir, template_fn)) - self._write_if_changed(filename, template.substitute(env)) - elif isinstance(env, str): - self._write_if_changed(filename, env) - else: - assert_never(env) - - - def write(self, filename: str, env_callable: Callable[[], Union[str, Union[str, Dict[str, Any]]]]) -> None: - self.write_with_template(filename, filename, env_callable) - - def write_sharded( - self, - filename: str, - items: Iterable[T], - *, - key_fn: Callable[[T], str], - env_callable: Callable[[T], Dict[str, List[str]]], - num_shards: int, - base_env: Optional[Dict[str, Any]] = None, - sharded_keys: Set[str] - ) -> None: - - everything: Dict[str, Any] = {'shard_id': 'Everything'} - shards: List[Dict[str, Any]] = [{'shard_id': f'_{i}'} for i in range(num_shards)] - all_shards = [everything] + shards - - if base_env is not None: - for shard in all_shards: - shard.update(base_env) - - for key in sharded_keys: - for shard in all_shards: - if key in shard: - assert isinstance(shard[key], list), "sharded keys in base_env must be a list" - shard[key] = shard[key].copy() - else: - shard[key] = [] - - - def merge_env(into: Dict[str, List[str]], from_: Dict[str, List[str]]) -> None: - for k, v in from_.items(): - assert k in sharded_keys, f"undeclared sharded key {k}" - into[k] += v - - for item in items: - key = key_fn(item) - sid = string_stable_hash(key) % num_shards - env = env_callable(item) - - merge_env(shards[sid], env) - merge_env(everything, env) - - dot_pos = filename.rfind('.') - if dot_pos == -1: - dot_pos = len(filename) - base_filename = filename[:dot_pos] - extension = filename[dot_pos:] - - for shard in all_shards: - shard_id = shard['shard_id'] - self.write_with_template(f"{base_filename}{shard_id}{extension}", - filename, - lambda: shard) - - # filenames is used to track compiled files, but FooEverything.cpp isn't meant to be compiled - self.filenames.discard( - f"{self.install_dir}/{base_filename}Everything{extension}") - - def write_outputs(self, filename: str) -> None: - """Write a file containing the list of all outputs which are - generated by this script.""" - self._write_if_changed( - filename, - ''.join(name + ";" for name in sorted(self.filenames))) - def get_custom_build_selector( provided_op_registration_allowlist: Optional[List[str]], op_selection_yaml_path: Optional[str]) -> SelectiveBuilder: diff --git a/tools/codegen/gen_backend_stubs.py b/tools/codegen/gen_backend_stubs.py index 5fad11c3438..5e8da81f9e4 100644 --- a/tools/codegen/gen_backend_stubs.py +++ b/tools/codegen/gen_backend_stubs.py @@ -5,11 +5,11 @@ import yaml import re from collections import namedtuple, Counter, defaultdict from typing import List, Dict, Union, Sequence, Optional -from tools.codegen.gen import FileManager, get_grouped_native_functions, parse_native_yaml +from tools.codegen.gen import get_grouped_native_functions, parse_native_yaml from tools.codegen.model import (BackendIndex, BackendMetadata, DispatchKey, NativeFunction, NativeFunctionsGroup, OperatorName) from tools.codegen.selective_build.selector import SelectiveBuilder -from tools.codegen.utils import Target, concatMap, context, YamlLoader +from tools.codegen.utils import Target, concatMap, context, YamlLoader, FileManager from tools.codegen.context import native_function_manager import tools.codegen.dest as dest import tools.codegen.api.dispatcher as dispatcher diff --git a/tools/codegen/model.py b/tools/codegen/model.py index e604e72d3a1..2578d4526fe 100644 --- a/tools/codegen/model.py +++ b/tools/codegen/model.py @@ -1,16 +1,12 @@ import re +from tools.codegen.utils import assert_never + from dataclasses import dataclass -from typing import List, Dict, Optional, Iterator, Tuple, Set, NoReturn, Sequence, Callable, Union +from typing import List, Dict, Optional, Iterator, Tuple, Set, Sequence, Callable, Union from enum import Enum, auto import itertools -# A little trick from https://github.com/python/mypy/issues/6366 -# for getting mypy to do exhaustiveness checking -# TODO: put this somewhere else, maybe -def assert_never(x: NoReturn) -> NoReturn: - raise AssertionError("Unhandled type: {}".format(type(x).__name__)) - # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # # # DATA MODEL diff --git a/tools/codegen/utils.py b/tools/codegen/utils.py index e2ef6a45612..c96b6cdd6c2 100644 --- a/tools/codegen/utils.py +++ b/tools/codegen/utils.py @@ -1,8 +1,13 @@ import re -from typing import Tuple, List, Iterable, Iterator, Callable, Sequence, TypeVar, Optional +import os +from typing import Tuple, List, Iterable, Iterator, Callable, Sequence, TypeVar, Optional, Dict, Any, Union, Set, NoReturn from enum import Enum import contextlib import textwrap +import hashlib +import functools + +from tools.codegen.code_template import CodeTemplate # Safely load fast C Yaml loader/dumper if they are available try: @@ -94,3 +99,134 @@ def context(msg_fn: Callable[[], str]) -> Iterator[None]: msg = f'{e.args[0]}\n{msg}' if e.args else msg e.args = (msg,) + e.args[1:] raise + +# A little trick from https://github.com/python/mypy/issues/6366 +# for getting mypy to do exhaustiveness checking +# TODO: put this somewhere else, maybe +def assert_never(x: NoReturn) -> NoReturn: + raise AssertionError("Unhandled type: {}".format(type(x).__name__)) + +@functools.lru_cache(maxsize=None) +def _read_template(template_fn: str) -> CodeTemplate: + return CodeTemplate.from_file(template_fn) + + +# String hash that's stable across different executions, unlike builtin hash +def string_stable_hash(s: str) -> int: + sha1 = hashlib.sha1(s.encode('latin1')).digest() + return int.from_bytes(sha1, byteorder='little') + +# A small abstraction for writing out generated files and keeping track +# of what files have been written (so you can write out a list of output +# files) +class FileManager: + install_dir: str + template_dir: str + dry_run: bool + filenames: Set[str] + + def __init__(self, install_dir: str, template_dir: str, dry_run: bool) -> None: + self.install_dir = install_dir + self.template_dir = template_dir + self.filenames = set() + self.dry_run = dry_run + + def _write_if_changed(self, filename: str, contents: str) -> None: + old_contents: Optional[str] + try: + with open(filename, 'r') as f: + old_contents = f.read() + except IOError: + old_contents = None + if contents != old_contents: + with open(filename, 'w') as f: + f.write(contents) + + def write_with_template(self, filename: str, template_fn: str, + env_callable: Callable[[], Union[str, Dict[str, Any]]]) -> None: + filename = '{}/{}'.format(self.install_dir, filename) + assert filename not in self.filenames, "duplicate file write {filename}" + self.filenames.add(filename) + if not self.dry_run: + env = env_callable() + if isinstance(env, dict): + # TODO: Update the comment reference to the correct location + if 'generated_comment' not in env: + comment = "@" + "generated by tools/codegen/gen.py" + comment += " from {}".format(os.path.basename(template_fn)) + env['generated_comment'] = comment + template = _read_template(os.path.join(self.template_dir, template_fn)) + self._write_if_changed(filename, template.substitute(env)) + elif isinstance(env, str): + self._write_if_changed(filename, env) + else: + assert_never(env) + + + def write(self, filename: str, env_callable: Callable[[], Union[str, Union[str, Dict[str, Any]]]]) -> None: + self.write_with_template(filename, filename, env_callable) + + def write_sharded( + self, + filename: str, + items: Iterable[T], + *, + key_fn: Callable[[T], str], + env_callable: Callable[[T], Dict[str, List[str]]], + num_shards: int, + base_env: Optional[Dict[str, Any]] = None, + sharded_keys: Set[str] + ) -> None: + + everything: Dict[str, Any] = {'shard_id': 'Everything'} + shards: List[Dict[str, Any]] = [{'shard_id': f'_{i}'} for i in range(num_shards)] + all_shards = [everything] + shards + + if base_env is not None: + for shard in all_shards: + shard.update(base_env) + + for key in sharded_keys: + for shard in all_shards: + if key in shard: + assert isinstance(shard[key], list), "sharded keys in base_env must be a list" + shard[key] = shard[key].copy() + else: + shard[key] = [] + + + def merge_env(into: Dict[str, List[str]], from_: Dict[str, List[str]]) -> None: + for k, v in from_.items(): + assert k in sharded_keys, f"undeclared sharded key {k}" + into[k] += v + + for item in items: + key = key_fn(item) + sid = string_stable_hash(key) % num_shards + env = env_callable(item) + + merge_env(shards[sid], env) + merge_env(everything, env) + + dot_pos = filename.rfind('.') + if dot_pos == -1: + dot_pos = len(filename) + base_filename = filename[:dot_pos] + extension = filename[dot_pos:] + + for shard in all_shards: + shard_id = shard['shard_id'] + self.write_with_template(f"{base_filename}{shard_id}{extension}", + filename, + lambda: shard) + + # filenames is used to track compiled files, but FooEverything.cpp isn't meant to be compiled + self.filenames.discard( + f"{self.install_dir}/{base_filename}Everything{extension}") + + def write_outputs(self, filename: str) -> None: + """Write a file containing the list of all outputs which are + generated by this script.""" + self._write_if_changed( + filename, + ''.join(name + ";" for name in sorted(self.filenames))) diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py index 47447bd8e44..332e784f0b0 100644 --- a/tools/pyi/gen_pyi.py +++ b/tools/pyi/gen_pyi.py @@ -6,7 +6,8 @@ import argparse from tools.codegen.model import Variant from tools.codegen.api.python import (PythonSignatureGroup, PythonSignatureNativeFunctionPair) -from tools.codegen.gen import FileManager, parse_native_yaml +from tools.codegen.gen import parse_native_yaml +from tools.codegen.utils import FileManager from typing import Sequence, List, Dict from ..autograd.gen_python_functions import should_generate_py_binding, load_signatures, group_overloads