From 2765243cd5e657a92142d09504dafadb058de63f Mon Sep 17 00:00:00 2001 From: Mengwei Liu Date: Sat, 10 Sep 2022 06:58:56 +0000 Subject: [PATCH] [torchgen] Refactor static_dispatch to take in source signature (#84384) Summary: Context: currently `static_dispatch` assumes that given a native function `f`, we always want to map from its `DispatchSignature` to its `CppSignature`. This assumption may not hold true for some use cases, where the source bindings may not come from its `DispatchSignature`. Here I'm changing the argument `sig: DispatcherSignature` to be `sig: Union[CppSignature, DispatcherSignature]`, also removes unused `f` Test Plan: Rely on added unit test. Differential Revision: D39192969 Pull Request resolved: https://github.com/pytorch/pytorch/pull/84384 Approved by: https://github.com/iseeyuan --- tools/test/test_codegen.py | 62 ++++++++++++++++++++++++++++++++++++++ torchgen/gen.py | 41 +++++++++++++++---------- 2 files changed, 87 insertions(+), 16 deletions(-) diff --git a/tools/test/test_codegen.py b/tools/test/test_codegen.py index cbce8b3bc5e..8bcecbb26e3 100644 --- a/tools/test/test_codegen.py +++ b/tools/test/test_codegen.py @@ -9,10 +9,13 @@ import torchgen.model import yaml from tools.autograd import gen_autograd_functions, load_derivatives +from torchgen.api.types import CppSignatureGroup, DispatcherSignature +from torchgen.context import native_function_manager from torchgen.gen import ( get_native_function_declarations, get_native_function_schema_registrations, LineLoader, + static_dispatch, ) from torchgen.model import ( BackendIndex, @@ -408,6 +411,65 @@ class TestNativeFunctionGeneratrion(unittest.TestCase): self.assertEqual(backend_metadata.kernel, "op_2_out") +# Test for static_dispatch +class TestStaticDispatchGeneratrion(unittest.TestCase): + def setUp(self) -> None: + self.backend_indices: Dict[ + DispatchKey, Dict[OperatorName, BackendMetadata] + ] = defaultdict(dict) + yaml_entry = """ +- func: op.out(Tensor self, *, Tensor(a!) 
out) -> Tensor(a!) + dispatch: + CompositeExplicitAutograd: op + """ + es = yaml.load(yaml_entry, Loader=LineLoader) + self.one_return_func, m = NativeFunction.from_yaml( + es[0], loc=Location(__file__, 1), valid_tags=set() + ) + + BackendIndex.grow_index(self.backend_indices, m) + dispatch_key = DispatchKey.CompositeExplicitAutograd + self.assertTrue(dispatch_key in self.backend_indices) + self.indices = [ + BackendIndex( + dispatch_key=dispatch_key, + use_out_as_primary=True, + external=False, + device_guard=False, + index=self.backend_indices[dispatch_key], + ) + ] + + def test_op_with_1_backend_generates_static_dispatch(self) -> None: + disp_sig = DispatcherSignature.from_schema(self.one_return_func.func) + with native_function_manager(self.one_return_func): + out = static_dispatch( + sig=disp_sig, + f=self.one_return_func, + backend_indices=self.indices, + ) + self.assertEqual( + out, "return at::compositeexplicitautograd::op_out(out, self);" + ) + + def test_op_with_cpp_sig_generates_static_dispatch(self) -> None: + sig_group = CppSignatureGroup.from_native_function( + self.one_return_func, + method=False, + fallback_binding=self.one_return_func.manual_cpp_binding, + ) + # cpp signature puts out at the front + with native_function_manager(self.one_return_func): + out = static_dispatch( + sig=sig_group.signature, + f=self.one_return_func, + backend_indices=self.indices, + ) + self.assertEqual( + out, "return at::compositeexplicitautograd::op_out(out, self);" + ) + + # Represents the most basic NativeFunction. Use dataclasses.replace() # to edit for use. 
DEFAULT_NATIVE_FUNCTION, _ = torchgen.model.NativeFunction.from_yaml( diff --git a/torchgen/gen.py b/torchgen/gen.py index 0d535275660..4ea126e5433 100644 --- a/torchgen/gen.py +++ b/torchgen/gen.py @@ -349,13 +349,12 @@ def static_dispatch_extra_headers(backends: List[BackendIndex]) -> List[str]: ] -# Translates arguments of a native function from DispatcherSignature form to CppSignature form with support for -# supporting usecases even when there is a memory_format argument along with tensor_option arguments. -# This usecase is not covered by tools.codegen.api.translate() yet as its application is limited to static dispatch -def translate_args_dispatcher_to_cpp( - sig: DispatcherSignature, +# Translates arguments of `sig` to CppSignature bindings. +# Note that we have a special case for `memory_format` argument and this case is not covered by +# tools.codegen.api.translate() yet as its application is limited to static dispatch. +def translate_args( + sig: Union[CppSignature, DispatcherSignature], cpp_sig: CppSignature, - f: NativeFunction, ) -> str: # Adds SpecialArgName.possibly_redundant_memory_format NamedCType for memory_format bindings @@ -377,20 +376,20 @@ def translate_args_dispatcher_to_cpp( output_bindings.append(binding) return output_bindings - disp_sig = sig - disp_bindings = disp_sig.arguments() + src_bindings = list(sig.arguments()) + goal_bindings = list(cpp_sig.arguments()) # When last argument of CPP signature has SpecialArgName.possibly_redundant_memory_format NCType, # get memory_format bindings of dispatcher signature to have the same NCType as well - for arg in cpp_sig.arguments(): + for arg in goal_bindings: if arg.nctype.name == SpecialArgName.possibly_redundant_memory_format: - disp_bindings = add_spl_memory_format_binding(disp_sig.arguments()) + src_bindings = add_spl_memory_format_binding(src_bindings) break - exprs = translate(disp_bindings, cpp_sig.arguments()) + exprs = translate(src_bindings, goal_bindings) return ", ".join(a.expr 
for a in exprs) def generate_static_dispatch_backend_call( - sig: DispatcherSignature, + sig: Union[CppSignature, DispatcherSignature], f: NativeFunction, backend_index: BackendIndex, ) -> str: @@ -403,7 +402,7 @@ def generate_static_dispatch_backend_call( cpp_sig = cpp_sigs.signature assert cpp_sig is not None name = cpp_sig.name() - exprs = translate_args_dispatcher_to_cpp(sig, cpp_sig, f) + exprs = translate_args(sig, cpp_sig) backend_metadata = backend_index.get_kernel(f) kernel_ns = ( backend_metadata.cpp_namespace @@ -415,7 +414,7 @@ def generate_static_dispatch_fallback_call( - sig: DispatcherSignature, + sig: Union[CppSignature, DispatcherSignature], f: NativeFunction, backend_indices: List[BackendIndex], ) -> str: @@ -428,7 +427,7 @@ def generate_static_dispatch_fallback_call( cpp_sig = cpp_sigs.signature assert cpp_sig is not None name = cpp_sig.name() - exprs = translate_args_dispatcher_to_cpp(sig, cpp_sig, f) + exprs = translate_args(sig, cpp_sig) ns = DEFAULT_KERNEL_NAMESPACE.replace("::native", "") if f.has_composite_explicit_autograd_kernel: return f"return {ns}::{DispatchKey.CompositeExplicitAutograd.lower()}::{name}({exprs});" @@ -444,10 +443,20 @@ def static_dispatch( - sig: DispatcherSignature, + sig: Union[CppSignature, DispatcherSignature], f: NativeFunction, backend_indices: List[BackendIndex], ) -> str: + """ + For a given `NativeFunction`, find out the corresponding backend and dispatch to it. If more than one + backend exists, fall back to static dispatch by determining dispatch key from inputs. + Arguments: + sig: A CppSignature or DispatcherSignature for this native function we want to use. + f: NativeFunction to generate static dispatch. + backend_indices: All available backends. 
+ Return: + C++ code to call backend-specific functions, e.g., "return at::cpu::add(self, other, scale);" + """ if len(backend_indices) == 0 or f.manual_kernel_registration: return ""