diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index 0c9e08f8d42..bb8cc3d09e0 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -57,6 +57,7 @@ from torch._inductor.utils import (
     run_and_get_kernels,
     run_and_get_triton_code,
     run_fw_bw_and_get_code,
+    triton_version_uses_attrs_dict,
 )
 from torch._inductor.virtualized import V
 from torch._prims_common import is_integer_dtype
@@ -13601,6 +13602,23 @@ if HAS_GPU and not TEST_WITH_ASAN:
                 r"reinterpret_tensor\(.*, \(1024, 50257\).*# reuse"
             ).run(code[1])
 
+        @unittest.skipIf(
+            not triton_version_uses_attrs_dict(),
+            "Test only applies to newer triton versions",
+        )
+        def test_triton_attrs_dict_constexpr_signature(self):
+            def fn(x):
+                return x.sin()
+
+            fn_c = torch.compile(fn)
+            x = torch.rand(16, device="cuda")
+
+            _, code = run_and_get_code(fn_c, x)
+
+            FileCheck().check("triton_meta").check("'signature':").check(
+                "'XBLOCK': 'constexpr'"
+            ).run(code[0])
+
     class RNNTest(TestCase):
         device_type = GPU_TYPE
 
diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py
index 7f0b2d0bcd2..e91c0ab897a 100644
--- a/torch/_inductor/codegen/common.py
+++ b/torch/_inductor/codegen/common.py
@@ -214,7 +214,7 @@ class DeviceCodegen:
     cpp_wrapper_codegen: Optional[WrapperConstructor] = None
 
 
-KernelArgType = Union[WorkspaceArg, TensorArg, SizeArg, TMADescriptorArg]
+KernelArgType = Union[WorkspaceArg, TensorArg, SizeArg, TMADescriptorArg, ConstexprArg]
 
 
 device_codegens: dict[str, DeviceCodegen] = {}
@@ -1142,6 +1142,16 @@ class InplacedBuffer(NamedTuple):
     other_names: list[str]
 
 
+@dataclasses.dataclass
+class ArgName:
+    name: str
+    # is_constexpr=True is used to attach a " : tl.constexpr" into the argument list
+    is_constexpr: bool = False
+
+    def full_name(self):
+        return f"{self.name}{' : tl.constexpr' if self.is_constexpr else ''}"
+
+
 class KernelArgs:
     @staticmethod
     def _lookup(prefix: str, odict: dict[SymbolLike, str], name: SymbolLike) -> str:
@@ -1346,15 +1356,17 @@ class KernelArgs:
         assert not self.workspace_args, "Workspace not supported on CPU "
         return arg_defs, call_args, arg_types
 
-    def python_argdefs(self):
-        arg_defs: list[str] = []
+    def python_argdefs(
+        self,
+    ) -> tuple[list[ArgName], list[str], list[KernelArgType], list[torch.dtype]]:
+        arg_defs: list[ArgName] = []
         call_args: list[str] = []
         arg_types: list[torch.dtype] = []
-        precompile_args: list[Union[TensorArg, SizeArg, WorkspaceArg]] = []
+        precompile_args: list[KernelArgType] = []
         for inplaced in unique(self.inplace_buffers.values()):
             if self._buffer_is_marked_removed(inplaced):
                 continue
-            arg_defs.append(inplaced.inner_name)
+            arg_defs.append(ArgName(inplaced.inner_name))
             call_args.append(inplaced.other_names[-1])
             arg_types.append(V.graph.get_dtype(inplaced.other_names[-1]))
             precompile_args.append(
@@ -1369,7 +1381,7 @@ class KernelArgs:
         ):
             if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
                 continue
-            arg_defs.append(inner)
+            arg_defs.append(ArgName(inner))
             call_args.append(outer)
             arg_types.append(V.graph.get_dtype(outer))
             precompile_args.append(
@@ -1380,14 +1392,14 @@ class KernelArgs:
                 )
             )
         for outer, inner in self.sizevars.items():
-            arg_defs.append(inner)
+            arg_defs.append(ArgName(inner))
             call_args.append(outer)
             arg_types.append(type(outer))  # type: ignore[arg-type]
             precompile_args.append(SizeArg(inner, outer))
             if V.graph.wrapper_code:
                 V.graph.wrapper_code.ensure_size_computed(outer)
         for arg in self.workspace_args:
-            arg_defs.append(arg.inner_name)
+            arg_defs.append(ArgName(arg.inner_name))
             call_args.append(arg.outer_name)
             precompile_args.append(arg)
             arg_types.append(arg.dtype)
diff --git a/torch/_inductor/codegen/cpp_bmm_template.py b/torch/_inductor/codegen/cpp_bmm_template.py
index f80687ef4bf..78154ca2c9c 100644
--- a/torch/_inductor/codegen/cpp_bmm_template.py
+++ b/torch/_inductor/codegen/cpp_bmm_template.py
@@ -9,6 +9,7 @@ import sympy
 from .. import ir
 from ..select_algorithm import PartialRender
 from ..virtualized import V
+from .common import ArgName
 from .cpp_gemm_template import CppGemmTemplate, GEMM_TEMPLATE
 from .cpp_micro_gemm import LayoutType
 from .cpp_template_kernel import CppTemplateKernel
@@ -136,7 +137,7 @@ class CppBmmTemplate(CppGemmTemplate):
         kernel: CppTemplateKernel,
         function_name: str,
         placeholder: str,
-        b_index: int,
+        b_index: str,
     ) -> str:
         """
         Similar to 'def_kernel' in cpp_template_kernel, but instead of generating a function definition,
@@ -150,8 +151,8 @@ class CppBmmTemplate(CppGemmTemplate):
             arg_defs, call_args, _, _ = kernel.args.python_argdefs()
             for i, buf in enumerate(call_args):
                 if buf == self.b_index:
-                    arg_defs[i] = b_index
-            call = f"{function_name}({', '.join(arg_defs)});"
+                    arg_defs[i] = ArgName(b_index)
+            call = f"{function_name}({', '.join(x.full_name() for x in arg_defs)});"
             return call
 
         assert placeholder not in kernel.render_hooks
diff --git a/torch/_inductor/codegen/multi_kernel.py b/torch/_inductor/codegen/multi_kernel.py
index ec494808367..59eba17ec24 100644
--- a/torch/_inductor/codegen/multi_kernel.py
+++ b/torch/_inductor/codegen/multi_kernel.py
@@ -26,7 +26,7 @@ log = logging.getLogger(__name__)
 
 def get_kernel_argdefs(kernel):
     arg_defs, _, _, _ = kernel.args.python_argdefs()
-    return arg_defs
+    return [x.name for x in arg_defs]
 
 
 def _get_all_args(args_list, arg_types_list=None):
diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py
index a011fa82d43..251076cbf03 100644
--- a/torch/_inductor/codegen/simd.py
+++ b/torch/_inductor/codegen/simd.py
@@ -1026,8 +1026,9 @@ class SIMDKernel(Kernel):
                 for name in call_args
             ]
 
+            argdef_names = [x.name for x in argdefs]
             msg = yellow_text(
-                f" param names {argdefs}\n buf names {call_args}\n strides {stride_order_list}"
+                f" param names {argdef_names}\n buf names {call_args}\n strides {stride_order_list}"
                 + f"\n sizes {size_list}\n sources {source_list}\n"
             )
             log.warning(msg)
diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
index 3257cf0496a..f2c28a90a8e 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -63,6 +63,7 @@ from ..virtualized import _ops as ops, OpsHandler, ReductionType, StoreMode, V
 from ..wrapper_benchmark import get_kernel_category_by_source_code
 from .block_analysis import BlockPatternMatcher
 from .common import (
+    ArgName,
     BackendFeature,
     ConstexprArg,
     CSE,
@@ -3353,14 +3354,14 @@ class TritonKernel(SIMDKernel):
                 isinstance(arg, WorkspaceArg)
                 and arg.zero_mode == WorkspaceZeroMode.ZERO_ON_CALL
             ):
-                mutated_args.add(argname)
+                mutated_args.add(argname.name)
 
         mutated_args = sorted(mutated_args)
 
         for tree in self.active_range_trees():
             sizearg = SizeArg(f"{tree.prefix}numel", tree.numel)
             signature.append(sizearg)
-            argdefs.append(sizearg.name)
+            argdefs.append(ArgName(sizearg.name))
             # constexpr version causes issues, see
             # https://github.com/pytorch/torchdynamo/pull/1362
             # triton_meta["constants"][len(argdefs)] = V.graph.sizevars.size_hint(
@@ -3372,7 +3373,7 @@ class TritonKernel(SIMDKernel):
             # new versions (but not old versions) of Triton need constexprs included in the signature
             if triton_version_uses_attrs_dict():
                 signature.append(ConstexprArg(arg_name))
-                argdefs.append(f"{arg_name} : tl.constexpr")
+                argdefs.append(ArgName(arg_name, is_constexpr=True))
 
         for tree in self.range_trees:
             if tree.is_reduction and self.persistent_reduction:
@@ -3428,7 +3429,7 @@ class TritonKernel(SIMDKernel):
             # https://github.com/pytorch/pytorch/issues/120478#issuecomment-1962822307
             # https://github.com/openai/triton/blob/231efe9ed2d200be0f69a07c298e4342b08efe3d/python/triton/runtime/jit.py#L384
             for arg_num in triton_meta["configs"][0].equal_to_1:  # type: ignore[index]
-                triton_meta["constants"][signature[arg_num].name] = 1  # type: ignore[index]
+                triton_meta["constants"][signature[arg_num].name] = 1  # type: ignore[index,union-attr]
 
         self.triton_meta = triton_meta
 
@@ -3481,7 +3482,7 @@ class TritonKernel(SIMDKernel):
         """
         code.splice(heuristics_line)
         code.writeline(
-            f"def {name or str(Placeholder.KERNEL_NAME)}({', '.join(argdefs)}):"
+            f"def {name or str(Placeholder.KERNEL_NAME)}({', '.join(x.full_name() for x in argdefs)}):"
        )
         with code.indent():
             self.codegen_static_numels(code)
diff --git a/torch/_inductor/codegen/triton_combo_kernel.py b/torch/_inductor/codegen/triton_combo_kernel.py
index 74be91c57b7..1220a892b2e 100644
--- a/torch/_inductor/codegen/triton_combo_kernel.py
+++ b/torch/_inductor/codegen/triton_combo_kernel.py
@@ -19,6 +19,7 @@ from ..scheduler import BaseSchedulerNode
 from ..utils import Placeholder, triton_version_uses_attrs_dict
 from ..virtualized import V
 from .common import (
+    ArgName,
     ConstexprArg,
     DeferredLine,
     IndentedBuffer,
@@ -649,7 +650,7 @@ class ComboKernel(Kernel):
         size_hints: dict[str, int],
         selected_kernel: TritonKernel,
         signature: list[Any],
-        argdefs: list[str],
+        argdefs: list[ArgName],
         pointwise_with_reduce: bool = False,
     ) -> str:
         can_use_32bit = all(k.index_dtype == "tl.int32" for k in self.sub_kernels)
@@ -749,14 +750,16 @@ class ComboKernel(Kernel):
         return [ConstexprArg(x) for x in block_names.keys()]
 
-    def add_numel_to_args(self, argdefs: list[str], signature: list[Any]) -> list[str]:
+    def add_numel_to_args(
+        self, argdefs: list[ArgName], signature: list[Any]
+    ) -> list[ArgName]:
         for num, sub_kernel in enumerate(self.sub_kernels):
             for tree in sub_kernel.active_range_trees():
                 if not isinstance(tree.numel, (Integer, int)):
                     # only if it is a dynamic shape
                     sizearg = SizeArg(f"{tree.prefix}numel_{num}", tree.numel)
                     signature.append(sizearg)
-                    argdefs.append(f"{tree.prefix}numel_{num}")
+                    argdefs.append(ArgName(f"{tree.prefix}numel_{num}"))
                     self.dynamic_shape_args.append(f"{tree.prefix}numel_{num}")
 
         return argdefs
 
@@ -834,7 +837,7 @@ class ComboKernel(Kernel):
         argdefs = self.add_numel_to_args(argdefs, signature)
         block_args = self.get_block_args()
         if self.enable_autotune:
-            argdefs.extend([f"{x.name}: tl.constexpr" for x in block_args])
+            argdefs.extend([ArgName(x.name, is_constexpr=True) for x in block_args])
         if triton_version_uses_attrs_dict():
             signature.extend(block_args)
 
@@ -849,7 +852,7 @@ class ComboKernel(Kernel):
             )
         )
         code.writeline(
-            f"def {name or str(Placeholder.KERNEL_NAME)}({', '.join(argdefs)}):"
+            f"def {name or str(Placeholder.KERNEL_NAME)}({', '.join(x.full_name() for x in argdefs)}):"
         )
 
         with code.indent():
diff --git a/torch/_inductor/codegen/triton_utils.py b/torch/_inductor/codegen/triton_utils.py
index 06e07fb1acb..78b97196338 100644
--- a/torch/_inductor/codegen/triton_utils.py
+++ b/torch/_inductor/codegen/triton_utils.py
@@ -10,6 +10,7 @@ from ..runtime.hints import AttrsDescriptorWrapper
 from ..utils import _type_of, expr_fits_within_32bit, triton_version_uses_attrs_dict
 from ..virtualized import V
 from .common import (
+    ArgName,
     ConstexprArg,
     KernelArgType,
     SizeArg,
@@ -104,13 +105,13 @@ def signature_to_meta(
     signature: list[KernelArgType],
     *,
     size_dtype: Optional[str],
-    argdefs: list[str],
+    argdefs: list[ArgName],
     indices: Optional[list[int]] = None,
 ) -> dict[str, str]:
     if indices is None:
         indices = list(range(len(signature)))
     return {
-        argdefs[i]: signature_of(arg, size_dtype=size_dtype)
+        argdefs[i].name: signature_of(arg, size_dtype=size_dtype)
         for i, arg in zip(indices, signature)
     }
 
diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py
index 508f325f09f..8ba1a08664c 100644
--- a/torch/_inductor/codegen/wrapper.py
+++ b/torch/_inductor/codegen/wrapper.py
@@ -46,6 +46,7 @@ from ..utils import (
 )
 from ..virtualized import V
 from .common import (
+    ArgName,
     CodeGen,
     DeferredLine,
     IndentedBuffer,
@@ -1657,7 +1658,7 @@ class PythonWrapperCodegen(CodeGen):
             signature,
             size_dtype=None,  # try to infer based on symints
             indices=non_constant_indices,
-            argdefs=kernel.arg_names,
+            argdefs=[ArgName(x) for x in kernel.arg_names],
         )
         triton_meta: dict[str, Any] = {
             "signature": triton_signature,
diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py
index 37d3cd1723f..e8efb9e2627 100644
--- a/torch/_inductor/select_algorithm.py
+++ b/torch/_inductor/select_algorithm.py
@@ -445,7 +445,7 @@ class TritonTemplateKernel(TritonKernel):
         def hook():
             # python_argdefs() cannot be run until after the rest of the template lazily adds more args
             arg_defs, *_ = self.args.python_argdefs()
-            return f"{', '.join(arg_defs)}"
+            return f"{', '.join(x.full_name() for x in arg_defs)}"
 
         self.render_hooks["<ARGDEFS>"] = hook
         return "<ARGDEFS>"
@@ -515,7 +515,9 @@ class TritonTemplateKernel(TritonKernel):
         code = IndentedBuffer()
         code.splice(gen_common_triton_imports())
         code.splice(self.jit_lines())
-        code.writeline(f"def {self.kernel_name}({', '.join(arg_defs)}):")
+        code.writeline(
+            f"def {self.kernel_name}({', '.join(x.full_name() for x in arg_defs)}):"
+        )
         with code.indent():
             code.splice(self.defines)
             code.splice(renames.getvalue())
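
Illustrative sketch (not part of the patch): how the new ArgName helper is meant to be consumed. The ArgName definition below mirrors the one added to torch/_inductor/codegen/common.py; the argument names and the kernel name are made up for illustration only.

import dataclasses


@dataclasses.dataclass
class ArgName:
    name: str
    # is_constexpr=True is used to attach a " : tl.constexpr" into the argument list
    is_constexpr: bool = False

    def full_name(self):
        return f"{self.name}{' : tl.constexpr' if self.is_constexpr else ''}"


# Hypothetical argdef list with the shape python_argdefs() now returns.
argdefs = [
    ArgName("in_ptr0"),
    ArgName("out_ptr0"),
    ArgName("xnumel"),
    ArgName("XBLOCK", is_constexpr=True),
]

# Kernel definitions render full_name(), so constexpr args keep their annotation...
print(f"def triton_poi_fused_example({', '.join(x.full_name() for x in argdefs)}):")
# def triton_poi_fused_example(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):

# ...while signature/metadata consumers use the bare .name.
print([x.name for x in argdefs])
# ['in_ptr0', 'out_ptr0', 'xnumel', 'XBLOCK']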