mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[torchgen] Refactor static_dispatch to take in source signature (#84384)
Summary: Context: currently `static_dispatch` assumes that given a native function `f`, we always want to map from its `DispatchSignature` to its `CppSignature`. This assumption may not hold true for some use cases, where the source bindings may not come from its `DispatchSignature`. Here I'm changing the argument `sig: DispatcherSignature` to be `sig: Union[CppSignature, DispatcherSignature]`, also removes unused `f` Test Plan: Rely on added unit test. Differential Revision: D39192969 Pull Request resolved: https://github.com/pytorch/pytorch/pull/84384 Approved by: https://github.com/iseeyuan
This commit is contained in:
parent
c5a8946e40
commit
2765243cd5
2 changed files with 87 additions and 16 deletions
|
|
@ -9,10 +9,13 @@ import torchgen.model
|
|||
import yaml
|
||||
|
||||
from tools.autograd import gen_autograd_functions, load_derivatives
|
||||
from torchgen.api.types import CppSignatureGroup, DispatcherSignature
|
||||
from torchgen.context import native_function_manager
|
||||
from torchgen.gen import (
|
||||
get_native_function_declarations,
|
||||
get_native_function_schema_registrations,
|
||||
LineLoader,
|
||||
static_dispatch,
|
||||
)
|
||||
from torchgen.model import (
|
||||
BackendIndex,
|
||||
|
|
@ -408,6 +411,65 @@ class TestNativeFunctionGeneratrion(unittest.TestCase):
|
|||
self.assertEqual(backend_metadata.kernel, "op_2_out")
|
||||
|
||||
|
||||
# Test for static_dispatch
|
||||
class TestStaticDispatchGeneratrion(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.backend_indices: Dict[
|
||||
DispatchKey, Dict[OperatorName, BackendMetadata]
|
||||
] = defaultdict(dict)
|
||||
yaml_entry = """
|
||||
- func: op.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: op
|
||||
"""
|
||||
es = yaml.load(yaml_entry, Loader=LineLoader)
|
||||
self.one_return_func, m = NativeFunction.from_yaml(
|
||||
es[0], loc=Location(__file__, 1), valid_tags=set()
|
||||
)
|
||||
|
||||
BackendIndex.grow_index(self.backend_indices, m)
|
||||
dispatch_key = DispatchKey.CompositeExplicitAutograd
|
||||
self.assertTrue(dispatch_key in self.backend_indices)
|
||||
self.indices = [
|
||||
BackendIndex(
|
||||
dispatch_key=dispatch_key,
|
||||
use_out_as_primary=True,
|
||||
external=False,
|
||||
device_guard=False,
|
||||
index=self.backend_indices[dispatch_key],
|
||||
)
|
||||
]
|
||||
|
||||
def test_op_with_1_backend_generates_static_dispatch(self) -> None:
|
||||
disp_sig = DispatcherSignature.from_schema(self.one_return_func.func)
|
||||
with native_function_manager(self.one_return_func):
|
||||
out = static_dispatch(
|
||||
sig=disp_sig,
|
||||
f=self.one_return_func,
|
||||
backend_indices=self.indices,
|
||||
)
|
||||
self.assertEqual(
|
||||
out, "return at::compositeexplicitautograd::op_out(out, self);"
|
||||
)
|
||||
|
||||
def test_op_with_cpp_sig_generates_static_dispatch(self) -> None:
|
||||
sig_group = CppSignatureGroup.from_native_function(
|
||||
self.one_return_func,
|
||||
method=False,
|
||||
fallback_binding=self.one_return_func.manual_cpp_binding,
|
||||
)
|
||||
# cpp signature puts out at the front
|
||||
with native_function_manager(self.one_return_func):
|
||||
out = static_dispatch(
|
||||
sig=sig_group.signature,
|
||||
f=self.one_return_func,
|
||||
backend_indices=self.indices,
|
||||
)
|
||||
self.assertEqual(
|
||||
out, "return at::compositeexplicitautograd::op_out(out, self);"
|
||||
)
|
||||
|
||||
|
||||
# Represents the most basic NativeFunction. Use dataclasses.replace()
|
||||
# to edit for use.
|
||||
DEFAULT_NATIVE_FUNCTION, _ = torchgen.model.NativeFunction.from_yaml(
|
||||
|
|
|
|||
|
|
@ -349,13 +349,12 @@ def static_dispatch_extra_headers(backends: List[BackendIndex]) -> List[str]:
|
|||
]
|
||||
|
||||
|
||||
# Translates arguments of a native function from DispatcherSignature form to CppSignature form with support for
|
||||
# supporting usecases even when there is a memory_format argument along with tensor_option arguments.
|
||||
# This usecase is not covered by tools.codegen.api.translate() yet as its application is limited to static dispatch
|
||||
def translate_args_dispatcher_to_cpp(
|
||||
sig: DispatcherSignature,
|
||||
# Translates arguments of `sig` to CppSignature bindings.
|
||||
# Note that we have a special case for `memory_format` argument and this case is not covered by
|
||||
# tools.codegen.api.translate() yet as its application is limited to static dispatch.
|
||||
def translate_args(
|
||||
sig: Union[CppSignature, DispatcherSignature],
|
||||
cpp_sig: CppSignature,
|
||||
f: NativeFunction,
|
||||
) -> str:
|
||||
|
||||
# Adds SpecialArgName.possibly_redundant_memory_format NamedCType for memory_format bindings
|
||||
|
|
@ -377,20 +376,20 @@ def translate_args_dispatcher_to_cpp(
|
|||
output_bindings.append(binding)
|
||||
return output_bindings
|
||||
|
||||
disp_sig = sig
|
||||
disp_bindings = disp_sig.arguments()
|
||||
src_bindings = list(sig.arguments())
|
||||
goal_bindings = list(cpp_sig.arguments())
|
||||
# When last argument of CPP signature has SpecialArgName.possibly_redundant_memory_format NCType,
|
||||
# get memory_format bindings of dispatcher signature to have the same NCType as well
|
||||
for arg in cpp_sig.arguments():
|
||||
for arg in goal_bindings:
|
||||
if arg.nctype.name == SpecialArgName.possibly_redundant_memory_format:
|
||||
disp_bindings = add_spl_memory_format_binding(disp_sig.arguments())
|
||||
src_bindings = add_spl_memory_format_binding(src_bindings)
|
||||
break
|
||||
exprs = translate(disp_bindings, cpp_sig.arguments())
|
||||
exprs = translate(src_bindings, goal_bindings)
|
||||
return ", ".join(a.expr for a in exprs)
|
||||
|
||||
|
||||
def generate_static_dispatch_backend_call(
|
||||
sig: DispatcherSignature,
|
||||
sig: Union[CppSignature, DispatcherSignature],
|
||||
f: NativeFunction,
|
||||
backend_index: BackendIndex,
|
||||
) -> str:
|
||||
|
|
@ -403,7 +402,7 @@ def generate_static_dispatch_backend_call(
|
|||
cpp_sig = cpp_sigs.signature
|
||||
assert cpp_sig is not None
|
||||
name = cpp_sig.name()
|
||||
exprs = translate_args_dispatcher_to_cpp(sig, cpp_sig, f)
|
||||
exprs = translate_args(sig, cpp_sig)
|
||||
backend_metadata = backend_index.get_kernel(f)
|
||||
kernel_ns = (
|
||||
backend_metadata.cpp_namespace
|
||||
|
|
@ -415,7 +414,7 @@ def generate_static_dispatch_backend_call(
|
|||
|
||||
|
||||
def generate_static_dispatch_fallback_call(
|
||||
sig: DispatcherSignature,
|
||||
sig: Union[CppSignature, DispatcherSignature],
|
||||
f: NativeFunction,
|
||||
backend_indices: List[BackendIndex],
|
||||
) -> str:
|
||||
|
|
@ -428,7 +427,7 @@ def generate_static_dispatch_fallback_call(
|
|||
cpp_sig = cpp_sigs.signature
|
||||
assert cpp_sig is not None
|
||||
name = cpp_sig.name()
|
||||
exprs = translate_args_dispatcher_to_cpp(sig, cpp_sig, f)
|
||||
exprs = translate_args(sig, cpp_sig)
|
||||
ns = DEFAULT_KERNEL_NAMESPACE.replace("::native", "")
|
||||
if f.has_composite_explicit_autograd_kernel:
|
||||
return f"return {ns}::{DispatchKey.CompositeExplicitAutograd.lower()}::{name}({exprs});"
|
||||
|
|
@ -444,10 +443,20 @@ def generate_static_dispatch_fallback_call(
|
|||
|
||||
|
||||
def static_dispatch(
|
||||
sig: DispatcherSignature,
|
||||
sig: Union[CppSignature, DispatcherSignature],
|
||||
f: NativeFunction,
|
||||
backend_indices: List[BackendIndex],
|
||||
) -> str:
|
||||
"""
|
||||
For a given `NativeFunction`, find out the corresponding backend and dispatch to it. If more than one
|
||||
backends exsit, fallback to static dispatch by determining dispatch key from inputs.
|
||||
Arguments:
|
||||
sig: A CppSignature or DispatcherSignature for this native function we want to use.
|
||||
f: NativeFunction to generate static dispatch.
|
||||
backend_indices: All available backends.
|
||||
Return:
|
||||
C++ code to call backend-specific functions, e.g., "return at::cpu::add(self, other, scale);"
|
||||
"""
|
||||
if len(backend_indices) == 0 or f.manual_kernel_registration:
|
||||
return ""
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue