From 406ce692cafd045daa3d2ebb496557e653541ae6 Mon Sep 17 00:00:00 2001 From: Mengwei Liu Date: Thu, 4 Aug 2022 07:48:44 +0000 Subject: [PATCH] [torchgen] Generate wrapper functions under custom namespaces (#81744) Summary: A follow up of #81581. Before these 2 PRs, if an operator with custom kernel namespace is added to `native_functions.yaml` (or any other yaml consumed by `torchgen`), although we are able to recognize the custom kernel in files such as `NativeFunctions.h` and `RegisterCPU.cpp`, we still generate backend specific wrappers under the hardcoded `at` namespace. This changes the behavior, by generating wrapper functions under custom namespaces. For example, if the entries in yaml file looks like: ``` - func: op_1(Tensor(a) self) -> Tensor(a) dispatch: CPU: at::op_1_kernel # ATen kernel - func: op_2(Tensor(a) self) -> Tensor(a) dispatch: CPU: custom::op_2_kernel # custom kernel ``` We generate the following code for `CPUFunctions_inl.h` and `RegisterCPU.cpp`: `CPUFunctions_inl.h`: ``` namespace at { namespace cpu { TORCH_API at::Tensor & op_1(const at::Tensor & self); } // namespace cpu } // namespace at namespace custom { namespace cpu { TORCH_API at::Tensor & op_2(const at::Tensor & self); } // namespace cpu } // namespace custom ``` Notice the difference between `at::cpu` and `custom::cpu`. Then the definition for these can be found in `RegisterCPU.cpp`. `RegisterCPU.cpp`: ``` #include "CPUFunctions.h" namespace at { namespace { at::Tensor & wrapper_op_1(const at::Tensor & self) { // No device check // DeviceGuard omitted return at::native::op_1_kernel(self); } } // anonymous namespace TORCH_LIBRARY_IMPL(aten, CPU, m) { m.impl("op_1", TORCH_FN(wrapper_op_1)); } namespace cpu { at::Tensor & op_1(at::Tensor & self) { return wrapper_op_1(self); } } // namespace cpu } // namespace at namespace custom { namespace { at::Tensor & wrapper_op_2(const at::Tensor & self) { // No device check // DeviceGuard omitted return at::native::op_2_kernel(self); } } // anonymous namespace TORCH_LIBRARY_IMPL(aten, CPU, m) { m.impl("op_2", TORCH_FN(wrapper_op_2)); } namespace cpu { at::Tensor & op_2(at::Tensor & self) { return wrapper_op_2(self); } } // namespace cpu } // namespace custom ``` The benefit for this change is that it unifies all the namespaces derived from custom ops. In the example above, there are: 1. `custom::native` for kernels 2. `custom::` e.g., `custom::cpu` for wrappers This customized operator will have nothing to do with `at::native`, `at::cpu` etc. Test Plan: This is very hard to test. I will refactor this logic, abstract out some layers so it's testable. Will do it in coming PRs Differential Revision: D37972772 Pull Request resolved: https://github.com/pytorch/pytorch/pull/81744 Approved by: https://github.com/bdhirsh --- BUILD.bazel | 1 + .../ATen/templates/DispatchKeyFunctions_inl.h | 5 - .../templates/RegisterDispatchDefinitions.ini | 24 ++ .../ATen/templates/RegisterDispatchKey.cpp | 27 +- build.bzl | 1 + torchgen/gen.py | 250 +++++++++++++----- torchgen/gen_backend_stubs.py | 42 +-- torchgen/utils.py | 37 ++- 8 files changed, 259 insertions(+), 128 deletions(-) create mode 100644 aten/src/ATen/templates/RegisterDispatchDefinitions.ini diff --git a/BUILD.bazel b/BUILD.bazel index 823a59bb63b..4c0791bffbb 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -1877,6 +1877,7 @@ test_suite( "aten/src/ATen/templates/LazyIr.h", "aten/src/ATen/templates/LazyNonNativeIr.h", "aten/src/ATen/templates/RegisterDispatchKey.cpp", + "aten/src/ATen/templates/RegisterDispatchDefinitions.ini", "aten/src/ATen/native/native_functions.yaml", "aten/src/ATen/native/tags.yaml", "aten/src/ATen/native/ts_native_functions.yaml", diff --git a/aten/src/ATen/templates/DispatchKeyFunctions_inl.h b/aten/src/ATen/templates/DispatchKeyFunctions_inl.h index 73bc1008a4f..fbb71c2cb12 100644 --- a/aten/src/ATen/templates/DispatchKeyFunctions_inl.h +++ b/aten/src/ATen/templates/DispatchKeyFunctions_inl.h @@ -18,10 +18,5 @@ ${DispatchKeyFunctions_inl_includes} -namespace at { -namespace ${dispatch_namespace} { ${dispatch_namespaced_declarations} - -} // namespace ${dispatch_namespace} -} // namespace at diff --git a/aten/src/ATen/templates/RegisterDispatchDefinitions.ini b/aten/src/ATen/templates/RegisterDispatchDefinitions.ini new file mode 100644 index 00000000000..3bf7f9b1bb3 --- /dev/null +++ b/aten/src/ATen/templates/RegisterDispatchDefinitions.ini @@ -0,0 +1,24 @@ +${ns_prologue} + +// NB: TORCH_LIBRARY_IMPL must be in an anonymous namespace to avoid +// ambiguity with conflicting identifiers that may have been defined in +// at namespace already. +namespace { + +${dispatch_helpers} + +${dispatch_anonymous_definitions} + +${static_init_dispatch_registrations} + +} // anonymous namespace + +${deferred_dispatch_registrations} + +namespace ${dispatch_namespace} { + +${dispatch_namespaced_definitions} + +} // namespace ${dispatch_namespace} + +${ns_epilogue} diff --git a/aten/src/ATen/templates/RegisterDispatchKey.cpp b/aten/src/ATen/templates/RegisterDispatchKey.cpp index df00c0d0e4a..7a1584d505f 100644 --- a/aten/src/ATen/templates/RegisterDispatchKey.cpp +++ b/aten/src/ATen/templates/RegisterDispatchKey.cpp @@ -50,28 +50,5 @@ $external_backend_headers $dispatch_headers $ops_headers - -namespace at { - -// NB: TORCH_LIBRARY_IMPL must be in an anonymous namespace to avoid -// ambiguity with conflicting identifiers that may have been defined in -// at namespace already. -namespace { - -${dispatch_helpers} - -${dispatch_anonymous_definitions} - -${static_init_dispatch_registrations} - -} // anonymous namespace - -${deferred_dispatch_registrations} - -namespace ${dispatch_namespace} { - -${dispatch_namespaced_definitions} - -} // namespace ${dispatch_namespace} - -} // namespace at +// See template file RegisterDispatchDefinitions.ini +$dispatch_definitions diff --git a/build.bzl b/build.bzl index ac9ceaa0559..5715e34786d 100644 --- a/build.bzl +++ b/build.bzl @@ -92,6 +92,7 @@ def define_targets(rules): ":LazyIr.h", ":LazyNonNativeIr.h", ":RegisterDispatchKey.cpp", + ":RegisterDispatchDefinitions.ini", ":native_functions.yaml", ":shape_inference.h", ":tags.yaml", diff --git a/torchgen/gen.py b/torchgen/gen.py index 7606bde5215..ba55471d047 100644 --- a/torchgen/gen.py +++ b/torchgen/gen.py @@ -15,6 +15,7 @@ import torchgen.api.meta as meta import torchgen.api.native as native import torchgen.api.structured as structured import torchgen.dest as dest + from torchgen.api import cpp from torchgen.api.translate import translate from torchgen.api.types import ( @@ -1408,6 +1409,168 @@ def get_native_function_declarations( return declarations +def get_kernel_namespace( + *, f: Union[NativeFunction, NativeFunctionsGroup], backend_idx: BackendIndex +) -> str: + backend_metadata = backend_idx.get_kernel(f) + assert not backend_metadata or "::native" in backend_metadata.cpp_namespace, ( + f"The kernel for function {f.func.name if isinstance(f, NativeFunction) else f.functional.func.name} " + f"with dispatch key {backend_idx.dispatch_key}" + f" has a namespace {backend_metadata.cpp_namespace} and it's not ending with '::native'." + ) + return ( + backend_metadata.cpp_namespace if backend_metadata else DEFAULT_KERNEL_NAMESPACE + ) + + +# Return native function definitions grouped by dispatch key and custom namespace. +# Used in RegisterDispatchKey.cpp and etc. +def get_native_function_definitions( + *, + fm: FileManager, + grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]], + dispatch_key: DispatchKey, + backend_idx: BackendIndex, + selector: SelectiveBuilder, + rocm: bool, + skip_dispatcher_op_registration: bool, + gen_dispatch_helpers: bool, +) -> List[str]: + definitions: List[str] = [] + ns_definitions: Dict[str, List[str]] = defaultdict(list) + anonymous_definitions: Dict[str, List[str]] = defaultdict(list) + registrations: Dict[str, Dict[str, List[str]]] = defaultdict(dict) + newline = "\n" + ns_gen = dest.RegisterDispatchKey( + backend_idx, + Target.NAMESPACED_DEFINITION, + selector, + rocm=rocm, + class_method_name=None, + skip_dispatcher_op_registration=skip_dispatcher_op_registration, + ) + anonymous_gen = dest.RegisterDispatchKey( + backend_idx, + Target.ANONYMOUS_DEFINITION, + selector, + rocm=rocm, + class_method_name=None, + skip_dispatcher_op_registration=skip_dispatcher_op_registration, + ) + reg_gen = dest.RegisterDispatchKey( + backend_idx, + Target.REGISTRATION, + selector, + rocm=rocm, + class_method_name=None, + skip_dispatcher_op_registration=skip_dispatcher_op_registration, + ) + for f in grouped_native_functions: + kernel_namespace = get_kernel_namespace(f=f, backend_idx=backend_idx).replace( + "::native", "" + ) + + ns_definitions[kernel_namespace].extend( + ns_gen(f), + ) + anonymous_definitions[kernel_namespace].extend( + anonymous_gen(f), + ) + namespace = ( + f.namespace if isinstance(f, NativeFunction) else f.functional.namespace + ) + if namespace not in registrations[kernel_namespace]: + registrations[kernel_namespace] = defaultdict(list) + registrations[kernel_namespace][namespace].extend( + reg_gen(f), + ) + + for kernel_namespace in ns_definitions: + if len(ns_definitions[kernel_namespace]) == 0: + continue + ns_helper = NamespaceHelper(namespace_str=kernel_namespace) + registration_body = "" + for namespace in registrations[kernel_namespace]: + if not registrations[kernel_namespace][namespace]: + continue + registration_body += f""" +TORCH_LIBRARY_IMPL({namespace}, {dispatch_key}, m) {{ + {newline.join(registrations[kernel_namespace][namespace])} +}};""" + definitions.extend( + fm.substitute_with_template( + "RegisterDispatchDefinitions.ini", + lambda: { + "ns_prologue": ns_helper.prologue, + "ns_epilogue": ns_helper.epilogue, + "dispatch_helpers": dest.gen_registration_helpers(backend_idx) + if gen_dispatch_helpers + else [], + "dispatch_anonymous_definitions": anonymous_definitions[ + kernel_namespace + ], + "static_init_dispatch_registrations": "" + if skip_dispatcher_op_registration + else registration_body, + "deferred_dispatch_registrations": "", + "dispatch_namespace": dispatch_key.lower(), + "dispatch_namespaced_definitions": ns_definitions[kernel_namespace], + }, + ).split(newline) + ) + + return definitions + + +# Return native function declarations grouped by dispatch key and custom namespace. +# Used in CPUFunctions_inl.h and etc. +def get_namespaced_declaration( + *, + grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]], + dispatch_key: DispatchKey, + backend_idx: BackendIndex, + selector: SelectiveBuilder, + rocm: bool, +) -> List[str]: + declarations: List[str] = [] + ns_grouped_kernels: Dict[str, List[str]] = defaultdict(list) + newline = "\n" + func = dest.RegisterDispatchKey( + backend_idx, + Target.NAMESPACED_DECLARATION, + selector, + rocm=rocm, + class_method_name=None, + skip_dispatcher_op_registration=False, + ) + for f in grouped_native_functions: + namespace = get_kernel_namespace(f=f, backend_idx=backend_idx).replace( + "native", dispatch_key.lower() + ) + + ns_grouped_kernels[namespace].extend( + func(f), + ) + + for namespace, kernels in ns_grouped_kernels.items(): + if len(kernels) == 0: + continue + ns_helper = NamespaceHelper( + namespace_str=namespace, entity_name="", max_level=3 + ) + ordered_kernels = list(OrderedDict.fromkeys(kernels)) + declarations.extend( + f""" +{ns_helper.prologue} +{newline.join(ordered_kernels)} +{ns_helper.epilogue} + """.split( + newline + ) + ) + return declarations + + # Return native function schema registration code for aten and other namespaces. def get_native_function_schema_registrations( *, @@ -1550,18 +1713,12 @@ def gen_aggregated_headers( lambda: { "DispatchKeyFunctions_inl_includes": [], "dispatch_namespace": dispatch_key.lower(), - "dispatch_namespaced_declarations": list( - concatMap( - dest.RegisterDispatchKey( - backend_indices[dispatch_key], - Target.NAMESPACED_DECLARATION, - selector, - rocm=rocm, - class_method_name=None, - skip_dispatcher_op_registration=False, - ), - grouped_native_functions, - ) + "dispatch_namespaced_declarations": get_namespaced_declaration( + grouped_native_functions=grouped_native_functions, + dispatch_key=dispatch_key, + backend_idx=backend_indices[dispatch_key], + selector=selector, + rocm=rocm, ), }, ) @@ -1998,33 +2155,17 @@ def gen_source_files( ) ns_grouped_native_functions[namespace].append(grouped_native_function) - static_init_dispatch_registrations = "" - for namespace, functions in ns_grouped_native_functions.items(): - dispatch_registrations_body = ( - "" - if skip_dispatcher_op_registration - else "\n".join( - list( - concatMap( - dest.RegisterDispatchKey( - backend_index, - Target.REGISTRATION, - selector, - rocm=rocm, - class_method_name=None, - skip_dispatcher_op_registration=skip_dispatcher_op_registration, - ), - functions, - ) - ) - ) - ) - - static_init_dispatch_registrations += f""" -TORCH_LIBRARY_IMPL({namespace}, {dispatch_key}, m) {{ - {dispatch_registrations_body} -}};""" dispatch_namespace = str(dispatch_key).lower() + dispatch_definitions = get_native_function_definitions( + fm=fm, + grouped_native_functions=grouped_native_functions, + dispatch_key=dispatch_key, + backend_idx=backend_index, + selector=selector, + rocm=rocm, + skip_dispatcher_op_registration=skip_dispatcher_op_registration, + gen_dispatch_helpers=True, + ) fm.write_with_template( f"Register{dispatch_key}.cpp", "RegisterDispatchKey.cpp", @@ -2037,37 +2178,8 @@ TORCH_LIBRARY_IMPL({namespace}, {dispatch_key}, m) {{ backend_index, per_operator_headers, rocm ), "ops_headers": operator_headers(), - "DispatchKey": dispatch_key, - "dispatch_namespace": dispatch_key.lower(), - "dispatch_helpers": dest.gen_registration_helpers(backend_index), - "dispatch_namespaced_definitions": list( - concatMap( - dest.RegisterDispatchKey( - backend_index, - Target.NAMESPACED_DEFINITION, - selector, - rocm=rocm, - class_method_name=None, - skip_dispatcher_op_registration=skip_dispatcher_op_registration, - ), - grouped_native_functions, - ) - ), - "dispatch_anonymous_definitions": list( - concatMap( - dest.RegisterDispatchKey( - backend_index, - Target.ANONYMOUS_DEFINITION, - selector, - rocm=rocm, - class_method_name=None, - skip_dispatcher_op_registration=skip_dispatcher_op_registration, - ), - grouped_native_functions, - ) - ), - "static_init_dispatch_registrations": static_init_dispatch_registrations, - "deferred_dispatch_registrations": "", + "dispatch_helpers": "", + "dispatch_definitions": dispatch_definitions, }, ) diff --git a/torchgen/gen_backend_stubs.py b/torchgen/gen_backend_stubs.py index ae346e6d3ac..37b4048146c 100644 --- a/torchgen/gen_backend_stubs.py +++ b/torchgen/gen_backend_stubs.py @@ -422,6 +422,8 @@ def gen_dispatcher_registrations( grouped_native_functions, ) ) + newline = "\n" + ns_helper = NamespaceHelper(namespace_str="at") deferred_dispatch_registrations = "" static_init_dispatch_registrations = "" if eager_registration: @@ -453,8 +455,6 @@ TORCH_API void Register${backend_name}${dispatch_key}NativeFunctions() { f"Register{dispatch_key}.cpp", "RegisterDispatchKey.cpp", lambda: { - "static_init_dispatch_registrations": static_init_dispatch_registrations, - "deferred_dispatch_registrations": deferred_dispatch_registrations, "extra_cuda_headers": "", "external_backend_headers": external_backend_headers_str, "ops_headers": "#include " @@ -465,21 +465,31 @@ TORCH_API void Register${backend_name}${dispatch_key}NativeFunctions() { "dispatch_headers": dest.gen_registration_headers( backend_index, per_operator_headers=per_operator_headers, rocm=False ), - "dispatch_helpers": dest.gen_registration_helpers(backend_index), - "dispatch_namespaced_definitions": "", - "dispatch_anonymous_definitions": list( - concatMap( - dest.RegisterDispatchKey( - backend_index, - Target.ANONYMOUS_DEFINITION, - selector, - rocm=False, - class_method_name=f"{class_name}", - skip_dispatcher_op_registration=False, + "dispatch_definitions": fm.substitute_with_template( + "RegisterDispatchDefinitions.ini", + lambda: { + "ns_prologue": ns_helper.prologue, + "ns_epilogue": ns_helper.epilogue, + "static_init_dispatch_registrations": static_init_dispatch_registrations, + "deferred_dispatch_registrations": deferred_dispatch_registrations, + "dispatch_helpers": dest.gen_registration_helpers(backend_index), + "dispatch_namespace": dispatch_key.lower(), + "dispatch_namespaced_definitions": "", + "dispatch_anonymous_definitions": list( + concatMap( + dest.RegisterDispatchKey( + backend_index, + Target.ANONYMOUS_DEFINITION, + selector, + rocm=False, + class_method_name=f"{class_name}", + skip_dispatcher_op_registration=False, + ), + grouped_native_functions, + ) ), - grouped_native_functions, - ) - ), + }, + ).split(newline), }, ) diff --git a/torchgen/utils.py b/torchgen/utils.py index c168f186f83..64c21817080 100644 --- a/torchgen/utils.py +++ b/torchgen/utils.py @@ -176,6 +176,25 @@ class FileManager: with open(filename, "w") as f: f.write(contents) + # Read from template file and replace pattern with callable (type could be dict or str). + def substitute_with_template( + self, template_fn: str, env_callable: Callable[[], Union[str, Dict[str, Any]]] + ) -> str: + template_path = os.path.join(self.template_dir, template_fn) + env = env_callable() + if isinstance(env, dict): + # TODO: Update the comment reference to the correct location + if "generated_comment" not in env: + comment = "@" + "generated by torchgen/gen.py" + comment += " from {}".format(os.path.basename(template_path)) + env["generated_comment"] = comment + template = _read_template(template_path) + return template.substitute(env) + elif isinstance(env, str): + return env + else: + assert_never(env) + def write_with_template( self, filename: str, @@ -186,19 +205,11 @@ class FileManager: assert filename not in self.filenames, "duplicate file write {filename}" self.filenames.add(filename) if not self.dry_run: - env = env_callable() - if isinstance(env, dict): - # TODO: Update the comment reference to the correct location - if "generated_comment" not in env: - comment = "@" + "generated by torchgen/gen.py" - comment += " from {}".format(os.path.basename(template_fn)) - env["generated_comment"] = comment - template = _read_template(os.path.join(self.template_dir, template_fn)) - self._write_if_changed(filename, template.substitute(env)) - elif isinstance(env, str): - self._write_if_changed(filename, env) - else: - assert_never(env) + substitute_out = self.substitute_with_template( + template_fn=template_fn, + env_callable=env_callable, + ) + self._write_if_changed(filename=filename, contents=substitute_out) def write( self,