[inductor][4/N] triton support post-#5512, fix constexpr signatures (#145583)

Prior to this PR, constexprs appeared in signatures as `{.. "XBLOCK : tl.constexpr": "constexpr"}` when they should have appeared as `{.. "XBLOCK": "constexpr"}` — the Python-level rendering suffix was leaking into the signature key.

This PR represents argument names as `ArgName` objects, which can optionally be marked as constexpr; the bare name is used for signature metadata while the constexpr annotation is only attached when rendering the kernel definition.
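
Concretely, the `triton_meta["signature"]` entry for a simple pointwise kernel changes roughly as follows (a hedged sketch with illustrative pointer names and type strings, not exact inductor output):

    # before this PR: the rendering suffix leaked into the signature key
    signature = {"in_ptr0": "*fp32", "out_ptr0": "*fp32", "xnumel": "i32",
                 "XBLOCK : tl.constexpr": "constexpr"}

    # after this PR: the key is the bare argument name
    signature = {"in_ptr0": "*fp32", "out_ptr0": "*fp32", "xnumel": "i32",
                 "XBLOCK": "constexpr"}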

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145583
Approved by: https://github.com/jansel
David Berard 2025-01-28 12:02:04 -08:00 committed by PyTorch MergeBot
parent 3f77002b96
commit 2e8c080ab1
10 changed files with 68 additions and 28 deletions

View file

@@ -57,6 +57,7 @@ from torch._inductor.utils import (
     run_and_get_kernels,
     run_and_get_triton_code,
     run_fw_bw_and_get_code,
+    triton_version_uses_attrs_dict,
 )
 from torch._inductor.virtualized import V
 from torch._prims_common import is_integer_dtype
@@ -13601,6 +13602,23 @@ if HAS_GPU and not TEST_WITH_ASAN:
                 r"reinterpret_tensor\(.*, \(1024, 50257\).*# reuse"
             ).run(code[1])
 
+        @unittest.skipIf(
+            not triton_version_uses_attrs_dict(),
+            "Test only applies to newer triton versions",
+        )
+        def test_triton_attrs_dict_constexpr_signature(self):
+            def fn(x):
+                return x.sin()
+
+            fn_c = torch.compile(fn)
+            x = torch.rand(16, device="cuda")
+            _, code = run_and_get_code(fn_c, x)
+            FileCheck().check("triton_meta").check("'signature':").check(
+                "'XBLOCK': 'constexpr'"
+            ).run(code[0])
+
     class RNNTest(TestCase):
         device_type = GPU_TYPE

View file

@@ -214,7 +214,7 @@ class DeviceCodegen:
     cpp_wrapper_codegen: Optional[WrapperConstructor] = None
 
-KernelArgType = Union[WorkspaceArg, TensorArg, SizeArg, TMADescriptorArg]
+KernelArgType = Union[WorkspaceArg, TensorArg, SizeArg, TMADescriptorArg, ConstexprArg]
 
 device_codegens: dict[str, DeviceCodegen] = {}
@@ -1142,6 +1142,16 @@ class InplacedBuffer(NamedTuple):
     other_names: list[str]
 
 
+@dataclasses.dataclass
+class ArgName:
+    name: str
+    # is_constexpr=True attaches " : tl.constexpr" to the argument in the rendered argument list
+    is_constexpr: bool = False
+
+    def full_name(self):
+        return f"{self.name}{' : tl.constexpr' if self.is_constexpr else ''}"
+
+
 class KernelArgs:
     @staticmethod
     def _lookup(prefix: str, odict: dict[SymbolLike, str], name: SymbolLike) -> str:
@@ -1346,15 +1356,17 @@ class KernelArgs:
         assert not self.workspace_args, "Workspace not supported on CPU "
         return arg_defs, call_args, arg_types
 
-    def python_argdefs(self):
-        arg_defs: list[str] = []
+    def python_argdefs(
+        self,
+    ) -> tuple[list[ArgName], list[str], list[KernelArgType], list[torch.dtype]]:
+        arg_defs: list[ArgName] = []
         call_args: list[str] = []
         arg_types: list[torch.dtype] = []
-        precompile_args: list[Union[TensorArg, SizeArg, WorkspaceArg]] = []
+        precompile_args: list[KernelArgType] = []
         for inplaced in unique(self.inplace_buffers.values()):
             if self._buffer_is_marked_removed(inplaced):
                 continue
-            arg_defs.append(inplaced.inner_name)
+            arg_defs.append(ArgName(inplaced.inner_name))
             call_args.append(inplaced.other_names[-1])
             arg_types.append(V.graph.get_dtype(inplaced.other_names[-1]))
             precompile_args.append(
@@ -1369,7 +1381,7 @@ class KernelArgs:
         ):
             if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
                 continue
-            arg_defs.append(inner)
+            arg_defs.append(ArgName(inner))
             call_args.append(outer)
             arg_types.append(V.graph.get_dtype(outer))
             precompile_args.append(
@@ -1380,14 +1392,14 @@ class KernelArgs:
             )
         )
         for outer, inner in self.sizevars.items():
-            arg_defs.append(inner)
+            arg_defs.append(ArgName(inner))
             call_args.append(outer)
             arg_types.append(type(outer))  # type: ignore[arg-type]
             precompile_args.append(SizeArg(inner, outer))
             if V.graph.wrapper_code:
                 V.graph.wrapper_code.ensure_size_computed(outer)
         for arg in self.workspace_args:
-            arg_defs.append(arg.inner_name)
+            arg_defs.append(ArgName(arg.inner_name))
             call_args.append(arg.outer_name)
             precompile_args.append(arg)
             arg_types.append(arg.dtype)
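
To make the new split concrete, here is a minimal, self-contained sketch of how ArgName behaves (mirroring the dataclass added above; the argument names are illustrative):

    import dataclasses

    @dataclasses.dataclass
    class ArgName:
        name: str
        is_constexpr: bool = False

        def full_name(self):
            return f"{self.name}{' : tl.constexpr' if self.is_constexpr else ''}"

    args = [ArgName("in_ptr0"), ArgName("xnumel"), ArgName("XBLOCK", is_constexpr=True)]
    print(", ".join(a.full_name() for a in args))  # in_ptr0, xnumel, XBLOCK : tl.constexpr
    print([a.name for a in args])                  # ['in_ptr0', 'xnumel', 'XBLOCK']

The rendered def line uses full_name(), while signature metadata uses the bare .name — which is exactly what fixes the bad dictionary keys.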

View file

@@ -9,6 +9,7 @@ import sympy
 from .. import ir
 from ..select_algorithm import PartialRender
 from ..virtualized import V
+from .common import ArgName
 from .cpp_gemm_template import CppGemmTemplate, GEMM_TEMPLATE
 from .cpp_micro_gemm import LayoutType
 from .cpp_template_kernel import CppTemplateKernel
@@ -136,7 +137,7 @@ class CppBmmTemplate(CppGemmTemplate):
         kernel: CppTemplateKernel,
         function_name: str,
         placeholder: str,
-        b_index: int,
+        b_index: str,
     ) -> str:
         """
         Similar to 'def_kernel' in cpp_template_kernel, but instead of generating a function definition,
@@ -150,8 +151,8 @@ class CppBmmTemplate(CppGemmTemplate):
             arg_defs, call_args, _, _ = kernel.args.python_argdefs()
             for i, buf in enumerate(call_args):
                 if buf == self.b_index:
-                    arg_defs[i] = b_index
-            call = f"{function_name}({', '.join(arg_defs)});"
+                    arg_defs[i] = ArgName(b_index)
+            call = f"{function_name}({', '.join(x.full_name() for x in arg_defs)});"
             return call
 
         assert placeholder not in kernel.render_hooks
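
As a hedged illustration of the rendered call (reusing the ArgName sketch above; the function and buffer names are invented), the hook swaps in the kernel-local index name and joins full names:

    arg_defs = [ArgName("X"), ArgName("W"), ArgName("Y")]
    arg_defs[2] = ArgName("b_index")  # replace the buffer matched by self.b_index
    call = f"single_thread_mm({', '.join(x.full_name() for x in arg_defs)});"
    print(call)  # single_thread_mm(X, W, b_index);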

View file

@@ -26,7 +26,7 @@ log = logging.getLogger(__name__)
 
 def get_kernel_argdefs(kernel):
     arg_defs, _, _, _ = kernel.args.python_argdefs()
-    return arg_defs
+    return [x.name for x in arg_defs]
 
 
 def _get_all_args(args_list, arg_types_list=None):

View file

@@ -1026,8 +1026,9 @@ class SIMDKernel(Kernel):
             for name in call_args
         ]
+        argdef_names = [x.name for x in argdefs]
         msg = yellow_text(
-            f" param names {argdefs}\n buf names {call_args}\n strides {stride_order_list}"
+            f" param names {argdef_names}\n buf names {call_args}\n strides {stride_order_list}"
             + f"\n sizes {size_list}\n sources {source_list}\n"
         )
         log.warning(msg)

View file

@@ -63,6 +63,7 @@ from ..virtualized import _ops as ops, OpsHandler, ReductionType, StoreMode, V
 from ..wrapper_benchmark import get_kernel_category_by_source_code
 from .block_analysis import BlockPatternMatcher
 from .common import (
+    ArgName,
     BackendFeature,
     ConstexprArg,
     CSE,
@@ -3353,14 +3354,14 @@ class TritonKernel(SIMDKernel):
                 isinstance(arg, WorkspaceArg)
                 and arg.zero_mode == WorkspaceZeroMode.ZERO_ON_CALL
             ):
-                mutated_args.add(argname)
+                mutated_args.add(argname.name)
 
         mutated_args = sorted(mutated_args)
 
         for tree in self.active_range_trees():
             sizearg = SizeArg(f"{tree.prefix}numel", tree.numel)
             signature.append(sizearg)
-            argdefs.append(sizearg.name)
+            argdefs.append(ArgName(sizearg.name))
             # constexpr version causes issues, see
             # https://github.com/pytorch/torchdynamo/pull/1362
             # triton_meta["constants"][len(argdefs)] = V.graph.sizevars.size_hint(
@@ -3372,7 +3373,7 @@ class TritonKernel(SIMDKernel):
             # new versions (but not old versions) of Triton need constexprs included in the signature
             if triton_version_uses_attrs_dict():
                 signature.append(ConstexprArg(arg_name))
-            argdefs.append(f"{arg_name} : tl.constexpr")
+            argdefs.append(ArgName(arg_name, is_constexpr=True))
 
         for tree in self.range_trees:
             if tree.is_reduction and self.persistent_reduction:
@@ -3428,7 +3429,7 @@ class TritonKernel(SIMDKernel):
             # https://github.com/pytorch/pytorch/issues/120478#issuecomment-1962822307
             # https://github.com/openai/triton/blob/231efe9ed2d200be0f69a07c298e4342b08efe3d/python/triton/runtime/jit.py#L384
             for arg_num in triton_meta["configs"][0].equal_to_1:  # type: ignore[index]
-                triton_meta["constants"][signature[arg_num].name] = 1  # type: ignore[index]
+                triton_meta["constants"][signature[arg_num].name] = 1  # type: ignore[index,union-attr]
 
         self.triton_meta = triton_meta
@@ -3481,7 +3482,7 @@ class TritonKernel(SIMDKernel):
         """
         code.splice(heuristics_line)
         code.writeline(
-            f"def {name or str(Placeholder.KERNEL_NAME)}({', '.join(argdefs)}):"
+            f"def {name or str(Placeholder.KERNEL_NAME)}({', '.join(x.full_name() for x in argdefs)}):"
         )
         with code.indent():
             self.codegen_static_numels(code)
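
The net effect on the generated source is a def line like the following (a sketch with an illustrative kernel name; the real name comes from Placeholder.KERNEL_NAME):

    argdefs = [ArgName("in_ptr0"), ArgName("out_ptr0"), ArgName("xnumel"),
               ArgName("XBLOCK", is_constexpr=True)]
    print(f"def triton_poi_fused_sin_0({', '.join(x.full_name() for x in argdefs)}):")
    # def triton_poi_fused_sin_0(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):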

View file

@@ -19,6 +19,7 @@ from ..scheduler import BaseSchedulerNode
 from ..utils import Placeholder, triton_version_uses_attrs_dict
 from ..virtualized import V
 from .common import (
+    ArgName,
     ConstexprArg,
     DeferredLine,
     IndentedBuffer,
@@ -649,7 +650,7 @@ class ComboKernel(Kernel):
         size_hints: dict[str, int],
         selected_kernel: TritonKernel,
         signature: list[Any],
-        argdefs: list[str],
+        argdefs: list[ArgName],
         pointwise_with_reduce: bool = False,
     ) -> str:
         can_use_32bit = all(k.index_dtype == "tl.int32" for k in self.sub_kernels)
@@ -749,14 +750,16 @@ class ComboKernel(Kernel):
         return [ConstexprArg(x) for x in block_names.keys()]
 
-    def add_numel_to_args(self, argdefs: list[str], signature: list[Any]) -> list[str]:
+    def add_numel_to_args(
+        self, argdefs: list[ArgName], signature: list[Any]
+    ) -> list[ArgName]:
         for num, sub_kernel in enumerate(self.sub_kernels):
             for tree in sub_kernel.active_range_trees():
                 if not isinstance(tree.numel, (Integer, int)):
                     # only if it is a dynamic shape
                     sizearg = SizeArg(f"{tree.prefix}numel_{num}", tree.numel)
                     signature.append(sizearg)
-                    argdefs.append(f"{tree.prefix}numel_{num}")
+                    argdefs.append(ArgName(f"{tree.prefix}numel_{num}"))
                     self.dynamic_shape_args.append(f"{tree.prefix}numel_{num}")
 
         return argdefs
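
For intuition, a hedged sketch (invented prefixes and counts) of what add_numel_to_args appends for two sub-kernels with dynamic shapes:

    argdefs = [ArgName("in_ptr0"), ArgName("out_ptr0")]
    # after add_numel_to_args with two dynamic-shape sub-kernels:
    argdefs += [ArgName("xnumel_0"), ArgName("xnumel_1")]
    print([a.full_name() for a in argdefs])
    # ['in_ptr0', 'out_ptr0', 'xnumel_0', 'xnumel_1']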
@@ -834,7 +837,7 @@ class ComboKernel(Kernel):
         argdefs = self.add_numel_to_args(argdefs, signature)
         block_args = self.get_block_args()
         if self.enable_autotune:
-            argdefs.extend([f"{x.name}: tl.constexpr" for x in block_args])
+            argdefs.extend([ArgName(x.name, is_constexpr=True) for x in block_args])
         if triton_version_uses_attrs_dict():
             signature.extend(block_args)
@@ -849,7 +852,7 @@ class ComboKernel(Kernel):
             )
         )
         code.writeline(
-            f"def {name or str(Placeholder.KERNEL_NAME)}({', '.join(argdefs)}):"
+            f"def {name or str(Placeholder.KERNEL_NAME)}({', '.join(x.full_name() for x in argdefs)}):"
         )
 
         with code.indent():

View file

@@ -10,6 +10,7 @@ from ..runtime.hints import AttrsDescriptorWrapper
 from ..utils import _type_of, expr_fits_within_32bit, triton_version_uses_attrs_dict
 from ..virtualized import V
 from .common import (
+    ArgName,
     ConstexprArg,
     KernelArgType,
     SizeArg,
@@ -104,13 +105,13 @@ def signature_to_meta(
     signature: list[KernelArgType],
     *,
     size_dtype: Optional[str],
-    argdefs: list[str],
+    argdefs: list[ArgName],
     indices: Optional[list[int]] = None,
 ) -> dict[str, str]:
     if indices is None:
         indices = list(range(len(signature)))
     return {
-        argdefs[i]: signature_of(arg, size_dtype=size_dtype)
+        argdefs[i].name: signature_of(arg, size_dtype=size_dtype)
         for i, arg in zip(indices, signature)
     }
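
With ArgName inputs, the resulting meta dict is keyed by bare names. A minimal sketch (type strings illustrative, assuming signature_of maps a float32 pointer to "*fp32" and a size arg to "i32"):

    argdefs = [ArgName("in_ptr0"), ArgName("xnumel"), ArgName("XBLOCK", is_constexpr=True)]
    # expected result of signature_to_meta(signature, size_dtype=None, argdefs=argdefs):
    # {"in_ptr0": "*fp32", "xnumel": "i32", "XBLOCK": "constexpr"}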

View file

@@ -46,6 +46,7 @@ from ..utils import (
 )
 from ..virtualized import V
 from .common import (
+    ArgName,
     CodeGen,
     DeferredLine,
     IndentedBuffer,
@@ -1657,7 +1658,7 @@ class PythonWrapperCodegen(CodeGen):
             signature,
             size_dtype=None,  # try to infer based on symints
             indices=non_constant_indices,
-            argdefs=kernel.arg_names,
+            argdefs=[ArgName(x) for x in kernel.arg_names],
         )
         triton_meta: dict[str, Any] = {
             "signature": triton_signature,

View file

@@ -445,7 +445,7 @@ class TritonTemplateKernel(TritonKernel):
         def hook():
             # python_argdefs() cannot be run until after the rest of the template lazily adds more args
             arg_defs, *_ = self.args.python_argdefs()
-            return f"{', '.join(arg_defs)}"
+            return f"{', '.join(x.full_name() for x in arg_defs)}"
 
         self.render_hooks["<ARGDEFS>"] = hook
         return "<ARGDEFS>"
@@ -515,7 +515,9 @@ class TritonTemplateKernel(TritonKernel):
         code = IndentedBuffer()
         code.splice(gen_common_triton_imports())
         code.splice(self.jit_lines())
-        code.writeline(f"def {self.kernel_name}({', '.join(arg_defs)}):")
+        code.writeline(
+            f"def {self.kernel_name}({', '.join(x.full_name() for x in arg_defs)}):"
+        )
         with code.indent():
             code.splice(self.defines)
             code.splice(renames.getvalue())