From 41865bd8ed027228a37cb17eb5eefe8cb46ac4da Mon Sep 17 00:00:00 2001 From: Mengwei Liu Date: Thu, 16 Feb 2023 02:43:14 +0000 Subject: [PATCH] [executorch] Add RuntimeContext to generated C++ API Signature (#94570) Summary: Pass runtime context all the way to kernel level. RegisterCodegenUnboxedKernels.cpp: ``` static Operator operators_to_register[] = { Operator( "aten::add.out", [](torch::executor::RuntimeContext & context, EValue** stack) { EValue& self = *stack[0]; EValue& other = *stack[1]; EValue& alpha = *stack[2]; EValue& out = *stack[3]; const torch::executor::Tensor & self_base = self.to(); const torch::executor::Tensor & other_base = other.to(); const torch::executor::Scalar & alpha_base = alpha.to(); torch::executor::Tensor & out_base = out.to(); EXECUTORCH_SCOPE_PROF("native_call_add.out"); torch::executor::aten::add_outf(context, self_base, other_base, alpha_base, out_base); } ), } ``` Functions.h ``` // aten::add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) TORCH_API inline at::Tensor & add_outf(torch::executor::RuntimeContext & context, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha, at::Tensor & out) { return torch::executor::native::add_out(self, other, alpha, out); } ``` Test Plan: TBD Differential Revision: D41325633 Pull Request resolved: https://github.com/pytorch/pytorch/pull/94570 Approved by: https://github.com/cccclai --- test/edge/RuntimeContext.h | 22 ++++++++ test/edge/operator_registry.h | 3 +- test/edge/test_operator_registration.cpp | 6 ++- tools/test/test_executorch_gen.py | 8 +-- tools/test/test_executorch_signatures.py | 58 +++++++++++++++++++++ torchgen/executorch/api/types/signatures.py | 17 +++--- torchgen/executorch/api/types/types.py | 26 ++++++++- torchgen/gen_executorch.py | 18 +++++-- 8 files changed, 139 insertions(+), 19 deletions(-) create mode 100644 test/edge/RuntimeContext.h create mode 100644 tools/test/test_executorch_signatures.py diff --git a/test/edge/RuntimeContext.h b/test/edge/RuntimeContext.h new file mode 100644 index 00000000000..5fa0e95707a --- /dev/null +++ b/test/edge/RuntimeContext.h @@ -0,0 +1,22 @@ +#pragma once + +namespace torch { +namespace executor { + +/** + * Bucket type abstraction that contains many elements of runtime state that + * a kernel author may want available, but would otherwise be unable to access. + * + * Forwarded along to all operators when running in lean mode. + * NOTE: Will not be forwarded to operators if running in ATen mode + * as those operators do not expect to receive a RuntimeContext and would not + * use it. + * + * This includes things like setting an error state, a scratch allocator for + * operators that need more then constant space, and a TensorResizer for dynamic + * shape tensors allowing programs to be more flexible with Tensor shape. + */ +class RuntimeContext {}; + +} // namespace executor +} // namespace torch diff --git a/test/edge/operator_registry.h b/test/edge/operator_registry.h index dee0b50c2a5..01b8d2374bc 100644 --- a/test/edge/operator_registry.h +++ b/test/edge/operator_registry.h @@ -4,13 +4,14 @@ #include #include "Evalue.h" +#include "RuntimeContext.h" #include #include namespace torch { namespace executor { -using OpFunction = std::function; +using OpFunction = std::function; template using ArrayRef = at::ArrayRef; diff --git a/test/edge/test_operator_registration.cpp b/test/edge/test_operator_registration.cpp index 89aed23df28..905c5de4c8f 100644 --- a/test/edge/test_operator_registration.cpp +++ b/test/edge/test_operator_registration.cpp @@ -18,7 +18,8 @@ TEST(OperatorRegistrationTest, Add) { for (size_t i = 0; i < 4; i++) { kernel_values[i] = &values[i]; } - op(kernel_values); + RuntimeContext context{}; + op(context, kernel_values); at::Tensor expected = at::ones({2, 3}); expected = at::fill(expected, 2); ASSERT_TRUE(expected.equal(kernel_values[3]->toTensor())); @@ -39,7 +40,8 @@ TEST(OperatorRegistrationTest, CustomAdd3) { for (size_t i = 0; i < 4; i++) { kernel_values[i] = &values[i]; } - op(kernel_values); + RuntimeContext context{}; + op(context, kernel_values); at::Tensor expected = at::ones({2, 3}); expected = at::fill(expected, 3); ASSERT_TRUE(expected.equal(kernel_values[3]->toTensor())); diff --git a/tools/test/test_executorch_gen.py b/tools/test/test_executorch_gen.py index 28f9516079c..25bd0197347 100644 --- a/tools/test/test_executorch_gen.py +++ b/tools/test/test_executorch_gen.py @@ -181,8 +181,8 @@ class TestGenFunctionsDeclarations(unittest.TestCase): namespace custom_1 { // custom_1::op_1() -> bool -TORCH_API inline bool op_1() { - return ::at::native::kernel_1(); +TORCH_API inline bool op_1(torch::executor::RuntimeContext & context) { + return ::at::native::kernel_1(context); } } // namespace custom_1 @@ -195,8 +195,8 @@ TORCH_API inline bool op_1() { namespace custom_2 { // custom_2::op_2() -> bool -TORCH_API inline bool op_2() { - return ::at::native::kernel_2(); +TORCH_API inline bool op_2(torch::executor::RuntimeContext & context) { + return ::at::native::kernel_2(context); } } // namespace custom_2 diff --git a/tools/test/test_executorch_signatures.py b/tools/test/test_executorch_signatures.py new file mode 100644 index 00000000000..6095fedc71f --- /dev/null +++ b/tools/test/test_executorch_signatures.py @@ -0,0 +1,58 @@ +import unittest + +from torchgen.executorch.api.types import ExecutorchCppSignature +from torchgen.local import parametrize +from torchgen.model import Location, NativeFunction + +DEFAULT_NATIVE_FUNCTION, _ = NativeFunction.from_yaml( + {"func": "foo.out(Tensor input, *, Tensor(a!) out) -> Tensor(a!)"}, + loc=Location(__file__, 1), + valid_tags=set(), +) + + +class ExecutorchCppSignatureTest(unittest.TestCase): + def setUp(self) -> None: + self.sig = ExecutorchCppSignature.from_native_function(DEFAULT_NATIVE_FUNCTION) + + def test_runtime_signature_contains_runtime_context(self) -> None: + # test if `RuntimeContext` argument exists in `RuntimeSignature` + with parametrize( + use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False + ): + args = self.sig.arguments(include_context=True) + self.assertEquals(len(args), 3) + self.assertTrue(any(a.name == "context" for a in args)) + + def test_runtime_signature_does_not_contain_runtime_context(self) -> None: + # test if `RuntimeContext` argument is missing in `RuntimeSignature` + with parametrize( + use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False + ): + args = self.sig.arguments(include_context=False) + self.assertEquals(len(args), 2) + self.assertFalse(any(a.name == "context" for a in args)) + + def test_runtime_signature_declaration_correct(self) -> None: + with parametrize( + use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False + ): + decl = self.sig.decl(include_context=True) + self.assertEquals( + decl, + ( + "torch::executor::Tensor & foo_outf(" + "torch::executor::RuntimeContext & context, " + "const torch::executor::Tensor & input, " + "torch::executor::Tensor & out)" + ), + ) + no_context_decl = self.sig.decl(include_context=False) + self.assertEquals( + no_context_decl, + ( + "torch::executor::Tensor & foo_outf(" + "const torch::executor::Tensor & input, " + "torch::executor::Tensor & out)" + ), + ) diff --git a/torchgen/executorch/api/types/signatures.py b/torchgen/executorch/api/types/signatures.py index 10f2c9d36a5..d79a4521644 100644 --- a/torchgen/executorch/api/types/signatures.py +++ b/torchgen/executorch/api/types/signatures.py @@ -6,12 +6,15 @@ import torchgen.api.cpp as aten_cpp from torchgen.api.types import Binding, CType from torchgen.model import FunctionSchema, NativeFunction +from .types import contextArg + @dataclass(frozen=True) class ExecutorchCppSignature: """ - This signature is merely a CppSignature with Executorch types. The inline definition - of CppSignature is generated in Functions.h and it's used by unboxing functions. + This signature is merely a CppSignature with Executorch types (optionally contains + RuntimeContext as well). The inline definition of CppSignature is generated in Functions.h + and it's used by unboxing functions. """ # The schema this signature is derived from @@ -25,8 +28,8 @@ class ExecutorchCppSignature: # and need to avoid naming collisions. prefix: str = "" - def arguments(self) -> List[Binding]: - return et_cpp.arguments( + def arguments(self, *, include_context: bool = True) -> List[Binding]: + return ([contextArg] if include_context else []) + et_cpp.arguments( self.func.arguments, faithful=True, # always faithful, out argument at the end method=False, # method not supported @@ -39,8 +42,10 @@ class ExecutorchCppSignature: faithful_name_for_out_overloads=True, ) - def decl(self, name: Optional[str] = None) -> str: - args_str = ", ".join(a.decl() for a in self.arguments()) + def decl(self, name: Optional[str] = None, *, include_context: bool = True) -> str: + args_str = ", ".join( + a.decl() for a in self.arguments(include_context=include_context) + ) if name is None: name = self.name() return f"{self.returns_type().cpp_type()} {name}({args_str})" diff --git a/torchgen/executorch/api/types/types.py b/torchgen/executorch/api/types/types.py index d4217c0b945..f6775ca61b6 100644 --- a/torchgen/executorch/api/types/types.py +++ b/torchgen/executorch/api/types/types.py @@ -1,7 +1,18 @@ from dataclasses import dataclass from typing import Dict -from torchgen.api.types import BaseCppType, boolT, CType, doubleT, longT +from torchgen.api.types import ( + BaseCppType, + BaseCType, + Binding, + boolT, + CType, + doubleT, + Expr, + longT, + MutRefCType, + NamedCType, +) from torchgen.model import BaseTy halfT = BaseCppType("torch::executor", "Half") @@ -14,6 +25,19 @@ scalarT = BaseCppType("torch::executor", "Scalar") memoryFormatT = BaseCppType("torch::executor", "MemoryFormat") intArrayRefT = BaseCppType("torch::executor", "IntArrayRef") optionalT = BaseCppType("torch::executor", "optional") +contextT = BaseCppType("torch::executor", "RuntimeContext") + +contextExpr = Expr( + expr="context", + type=NamedCType(name="context", type=MutRefCType(BaseCType(contextT))), +) + +contextArg = Binding( + name="context", + nctype=contextExpr.type, + argument=None, # type: ignore[arg-type] + default=None, +) BaseTypeToCppMapping: Dict[BaseTy, BaseCppType] = { BaseTy.int: longT, diff --git a/torchgen/gen_executorch.py b/torchgen/gen_executorch.py index a7a820e774a..621d14d4c1c 100644 --- a/torchgen/gen_executorch.py +++ b/torchgen/gen_executorch.py @@ -17,7 +17,7 @@ from torchgen.executorch.api.custom_ops import ( ComputeNativeFunctionStub, gen_custom_ops_registration, ) -from torchgen.executorch.api.types import ExecutorchCppSignature +from torchgen.executorch.api.types import contextArg, ExecutorchCppSignature from torchgen.executorch.api.unboxing import Unboxing from torchgen.gen import ( get_custom_build_selector, @@ -149,14 +149,16 @@ class ComputeCodegenUnboxedKernels: ).most_faithful_signature() argument_type_gen = aten_cpp.argumenttype_type return_type_gen = aten_cpp.returns_type + arguments = sig.arguments() else: sig = ExecutorchCppSignature.from_native_function(f) argument_type_gen = et_cpp.argumenttype_type return_type_gen = et_cpp.returns_type + arguments = sig.arguments(include_context=False) # parse arguments into C++ code binding_list, code_list = Unboxing( argument_type_gen=argument_type_gen - ).convert_arguments(sig.arguments()) + ).convert_arguments(arguments) # for each C++ argument, generate the conversion code code_connector = "\n\t" @@ -185,11 +187,12 @@ class ComputeCodegenUnboxedKernels: return f""" Operator( "{f.namespace}::{f.func.name}", - [](EValue** stack) {{ + []({contextArg.defn()}, EValue** stack) {{ + {"(void)context;" if self.use_aten_lib else ""} {code_connector.join(code_list)} EXECUTORCH_SCOPE_PROF("native_call_{f.func.name}"); - {ret_prefix}torch::executor::{f.namespace}::{sig.name()}({args_str}); + {ret_prefix}torch::executor::{f.namespace}::{sig.name()}({"" if self.use_aten_lib else "context, "}{args_str}); {return_assignment} }} @@ -229,7 +232,12 @@ def compute_native_function_declaration( if metadata is None: return [] prefix = "static" if backend_index.external else "TORCH_API" - return [f"{prefix} {sig.decl(name=metadata.kernel)};"] + # for kernels in lean mode, we declare two versions, one with context and one without. + # In the end we will cleanup the unused one. + return [ + f"{prefix} {sig.decl(name=metadata.kernel)};", + f"{prefix} {sig.decl(name=metadata.kernel, include_context=False)};", + ] def gen_functions_declarations(