[torchgen] Refactor static_dispatch to take in source signature (#84384)

Summary: Context: currently `static_dispatch` assumes that given a native function `f`, we always want to map from its `DispatchSignature` to its `CppSignature`. This assumption may not hold true for some use cases, where the source bindings may not come from its `DispatchSignature`. Here I'm changing the argument `sig: DispatcherSignature` to be `sig: Union[CppSignature, DispatcherSignature]`, also removes unused `f`

Test Plan: Rely on added unit test.

Differential Revision: D39192969

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84384
Approved by: https://github.com/iseeyuan
This commit is contained in:
Mengwei Liu 2022-09-10 06:58:56 +00:00 committed by PyTorch MergeBot
parent c5a8946e40
commit 2765243cd5
2 changed files with 87 additions and 16 deletions

View file

@ -9,10 +9,13 @@ import torchgen.model
import yaml
from tools.autograd import gen_autograd_functions, load_derivatives
from torchgen.api.types import CppSignatureGroup, DispatcherSignature
from torchgen.context import native_function_manager
from torchgen.gen import (
get_native_function_declarations,
get_native_function_schema_registrations,
LineLoader,
static_dispatch,
)
from torchgen.model import (
BackendIndex,
@ -408,6 +411,65 @@ class TestNativeFunctionGeneratrion(unittest.TestCase):
self.assertEqual(backend_metadata.kernel, "op_2_out")
# Test for static_dispatch
class TestStaticDispatchGeneratrion(unittest.TestCase):
def setUp(self) -> None:
self.backend_indices: Dict[
DispatchKey, Dict[OperatorName, BackendMetadata]
] = defaultdict(dict)
yaml_entry = """
- func: op.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
CompositeExplicitAutograd: op
"""
es = yaml.load(yaml_entry, Loader=LineLoader)
self.one_return_func, m = NativeFunction.from_yaml(
es[0], loc=Location(__file__, 1), valid_tags=set()
)
BackendIndex.grow_index(self.backend_indices, m)
dispatch_key = DispatchKey.CompositeExplicitAutograd
self.assertTrue(dispatch_key in self.backend_indices)
self.indices = [
BackendIndex(
dispatch_key=dispatch_key,
use_out_as_primary=True,
external=False,
device_guard=False,
index=self.backend_indices[dispatch_key],
)
]
def test_op_with_1_backend_generates_static_dispatch(self) -> None:
disp_sig = DispatcherSignature.from_schema(self.one_return_func.func)
with native_function_manager(self.one_return_func):
out = static_dispatch(
sig=disp_sig,
f=self.one_return_func,
backend_indices=self.indices,
)
self.assertEqual(
out, "return at::compositeexplicitautograd::op_out(out, self);"
)
def test_op_with_cpp_sig_generates_static_dispatch(self) -> None:
sig_group = CppSignatureGroup.from_native_function(
self.one_return_func,
method=False,
fallback_binding=self.one_return_func.manual_cpp_binding,
)
# cpp signature puts out at the front
with native_function_manager(self.one_return_func):
out = static_dispatch(
sig=sig_group.signature,
f=self.one_return_func,
backend_indices=self.indices,
)
self.assertEqual(
out, "return at::compositeexplicitautograd::op_out(out, self);"
)
# Represents the most basic NativeFunction. Use dataclasses.replace()
# to edit for use.
DEFAULT_NATIVE_FUNCTION, _ = torchgen.model.NativeFunction.from_yaml(

View file

@ -349,13 +349,12 @@ def static_dispatch_extra_headers(backends: List[BackendIndex]) -> List[str]:
]
# Translates arguments of a native function from DispatcherSignature form to CppSignature form with support for
# supporting usecases even when there is a memory_format argument along with tensor_option arguments.
# This usecase is not covered by tools.codegen.api.translate() yet as its application is limited to static dispatch
def translate_args_dispatcher_to_cpp(
sig: DispatcherSignature,
# Translates arguments of `sig` to CppSignature bindings.
# Note that we have a special case for `memory_format` argument and this case is not covered by
# tools.codegen.api.translate() yet as its application is limited to static dispatch.
def translate_args(
sig: Union[CppSignature, DispatcherSignature],
cpp_sig: CppSignature,
f: NativeFunction,
) -> str:
# Adds SpecialArgName.possibly_redundant_memory_format NamedCType for memory_format bindings
@ -377,20 +376,20 @@ def translate_args_dispatcher_to_cpp(
output_bindings.append(binding)
return output_bindings
disp_sig = sig
disp_bindings = disp_sig.arguments()
src_bindings = list(sig.arguments())
goal_bindings = list(cpp_sig.arguments())
# When last argument of CPP signature has SpecialArgName.possibly_redundant_memory_format NCType,
# get memory_format bindings of dispatcher signature to have the same NCType as well
for arg in cpp_sig.arguments():
for arg in goal_bindings:
if arg.nctype.name == SpecialArgName.possibly_redundant_memory_format:
disp_bindings = add_spl_memory_format_binding(disp_sig.arguments())
src_bindings = add_spl_memory_format_binding(src_bindings)
break
exprs = translate(disp_bindings, cpp_sig.arguments())
exprs = translate(src_bindings, goal_bindings)
return ", ".join(a.expr for a in exprs)
def generate_static_dispatch_backend_call(
sig: DispatcherSignature,
sig: Union[CppSignature, DispatcherSignature],
f: NativeFunction,
backend_index: BackendIndex,
) -> str:
@ -403,7 +402,7 @@ def generate_static_dispatch_backend_call(
cpp_sig = cpp_sigs.signature
assert cpp_sig is not None
name = cpp_sig.name()
exprs = translate_args_dispatcher_to_cpp(sig, cpp_sig, f)
exprs = translate_args(sig, cpp_sig)
backend_metadata = backend_index.get_kernel(f)
kernel_ns = (
backend_metadata.cpp_namespace
@ -415,7 +414,7 @@ def generate_static_dispatch_backend_call(
def generate_static_dispatch_fallback_call(
sig: DispatcherSignature,
sig: Union[CppSignature, DispatcherSignature],
f: NativeFunction,
backend_indices: List[BackendIndex],
) -> str:
@ -428,7 +427,7 @@ def generate_static_dispatch_fallback_call(
cpp_sig = cpp_sigs.signature
assert cpp_sig is not None
name = cpp_sig.name()
exprs = translate_args_dispatcher_to_cpp(sig, cpp_sig, f)
exprs = translate_args(sig, cpp_sig)
ns = DEFAULT_KERNEL_NAMESPACE.replace("::native", "")
if f.has_composite_explicit_autograd_kernel:
return f"return {ns}::{DispatchKey.CompositeExplicitAutograd.lower()}::{name}({exprs});"
@ -444,10 +443,20 @@ def generate_static_dispatch_fallback_call(
def static_dispatch(
sig: DispatcherSignature,
sig: Union[CppSignature, DispatcherSignature],
f: NativeFunction,
backend_indices: List[BackendIndex],
) -> str:
"""
For a given `NativeFunction`, find out the corresponding backend and dispatch to it. If more than one
backends exsit, fallback to static dispatch by determining dispatch key from inputs.
Arguments:
sig: A CppSignature or DispatcherSignature for this native function we want to use.
f: NativeFunction to generate static dispatch.
backend_indices: All available backends.
Return:
C++ code to call backend-specific functions, e.g., "return at::cpu::add(self, other, scale);"
"""
if len(backend_indices) == 0 or f.manual_kernel_registration:
return ""