mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[inductor][4/N] triton support post-#5512, fix constexpr signatures (#145583)
Prior to this PR, constexprs were appearing in signatures as `{.. "XBLOCK : tl.constexpr": "constexpr"}` when they really should appear as `{.. "XBLOCK": "constexpr"}`.
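This PR represents the argument names as ArgName objects, which can optionally be marked as constexpr: the bare name is used as the signature key, and the ` : tl.constexpr` suffix is added only when the kernel's argument list is rendered.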
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145583
Approved by: https://github.com/jansel
This commit is contained in:
parent
3f77002b96
commit
2e8c080ab1
10 changed files with 68 additions and 28 deletions
|
|
@ -57,6 +57,7 @@ from torch._inductor.utils import (
|
|||
run_and_get_kernels,
|
||||
run_and_get_triton_code,
|
||||
run_fw_bw_and_get_code,
|
||||
triton_version_uses_attrs_dict,
|
||||
)
|
||||
from torch._inductor.virtualized import V
|
||||
from torch._prims_common import is_integer_dtype
|
||||
|
|
@ -13601,6 +13602,23 @@ if HAS_GPU and not TEST_WITH_ASAN:
|
|||
r"reinterpret_tensor\(.*, \(1024, 50257\).*# reuse"
|
||||
).run(code[1])
|
||||
|
||||
@unittest.skipIf(
|
||||
not triton_version_uses_attrs_dict(),
|
||||
"Test only applies to newer triton versions",
|
||||
)
|
||||
def test_triton_attrs_dict_constexpr_signature(self):
|
||||
def fn(x):
|
||||
return x.sin()
|
||||
|
||||
fn_c = torch.compile(fn)
|
||||
x = torch.rand(16, device="cuda")
|
||||
|
||||
_, code = run_and_get_code(fn_c, x)
|
||||
|
||||
FileCheck().check("triton_meta").check("'signature':").check(
|
||||
"'XBLOCK': 'constexpr'"
|
||||
).run(code[0])
|
||||
|
||||
class RNNTest(TestCase):
|
||||
device_type = GPU_TYPE
|
||||
|
||||
|
|
|
|||
|
|
@ -214,7 +214,7 @@ class DeviceCodegen:
|
|||
cpp_wrapper_codegen: Optional[WrapperConstructor] = None
|
||||
|
||||
|
||||
KernelArgType = Union[WorkspaceArg, TensorArg, SizeArg, TMADescriptorArg]
|
||||
KernelArgType = Union[WorkspaceArg, TensorArg, SizeArg, TMADescriptorArg, ConstexprArg]
|
||||
|
||||
device_codegens: dict[str, DeviceCodegen] = {}
|
||||
|
||||
|
|
@ -1142,6 +1142,16 @@ class InplacedBuffer(NamedTuple):
|
|||
other_names: list[str]
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ArgName:
|
||||
name: str
|
||||
# is_constexpr=True makes full_name() append " : tl.constexpr" to the name when it is rendered in the argument list
|
||||
is_constexpr: bool = False
|
||||
|
||||
def full_name(self):
|
||||
return f"{self.name}{' : tl.constexpr' if self.is_constexpr else ''}"
|
||||
|
||||
|
||||
class KernelArgs:
|
||||
@staticmethod
|
||||
def _lookup(prefix: str, odict: dict[SymbolLike, str], name: SymbolLike) -> str:
|
||||
|
|
@ -1346,15 +1356,17 @@ class KernelArgs:
|
|||
assert not self.workspace_args, "Workspace not supported on CPU "
|
||||
return arg_defs, call_args, arg_types
|
||||
|
||||
def python_argdefs(self):
|
||||
arg_defs: list[str] = []
|
||||
def python_argdefs(
|
||||
self,
|
||||
) -> tuple[list[ArgName], list[str], list[KernelArgType], list[torch.dtype]]:
|
||||
arg_defs: list[ArgName] = []
|
||||
call_args: list[str] = []
|
||||
arg_types: list[torch.dtype] = []
|
||||
precompile_args: list[Union[TensorArg, SizeArg, WorkspaceArg]] = []
|
||||
precompile_args: list[KernelArgType] = []
|
||||
for inplaced in unique(self.inplace_buffers.values()):
|
||||
if self._buffer_is_marked_removed(inplaced):
|
||||
continue
|
||||
arg_defs.append(inplaced.inner_name)
|
||||
arg_defs.append(ArgName(inplaced.inner_name))
|
||||
call_args.append(inplaced.other_names[-1])
|
||||
arg_types.append(V.graph.get_dtype(inplaced.other_names[-1]))
|
||||
precompile_args.append(
|
||||
|
|
@ -1369,7 +1381,7 @@ class KernelArgs:
|
|||
):
|
||||
if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
|
||||
continue
|
||||
arg_defs.append(inner)
|
||||
arg_defs.append(ArgName(inner))
|
||||
call_args.append(outer)
|
||||
arg_types.append(V.graph.get_dtype(outer))
|
||||
precompile_args.append(
|
||||
|
|
@ -1380,14 +1392,14 @@ class KernelArgs:
|
|||
)
|
||||
)
|
||||
for outer, inner in self.sizevars.items():
|
||||
arg_defs.append(inner)
|
||||
arg_defs.append(ArgName(inner))
|
||||
call_args.append(outer)
|
||||
arg_types.append(type(outer)) # type: ignore[arg-type]
|
||||
precompile_args.append(SizeArg(inner, outer))
|
||||
if V.graph.wrapper_code:
|
||||
V.graph.wrapper_code.ensure_size_computed(outer)
|
||||
for arg in self.workspace_args:
|
||||
arg_defs.append(arg.inner_name)
|
||||
arg_defs.append(ArgName(arg.inner_name))
|
||||
call_args.append(arg.outer_name)
|
||||
precompile_args.append(arg)
|
||||
arg_types.append(arg.dtype)
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ import sympy
|
|||
from .. import ir
|
||||
from ..select_algorithm import PartialRender
|
||||
from ..virtualized import V
|
||||
from .common import ArgName
|
||||
from .cpp_gemm_template import CppGemmTemplate, GEMM_TEMPLATE
|
||||
from .cpp_micro_gemm import LayoutType
|
||||
from .cpp_template_kernel import CppTemplateKernel
|
||||
|
|
@ -136,7 +137,7 @@ class CppBmmTemplate(CppGemmTemplate):
|
|||
kernel: CppTemplateKernel,
|
||||
function_name: str,
|
||||
placeholder: str,
|
||||
b_index: int,
|
||||
b_index: str,
|
||||
) -> str:
|
||||
"""
|
||||
Similar to 'def_kernel' in cpp_template_kernel, but instead of generating a function definition,
|
||||
|
|
@ -150,8 +151,8 @@ class CppBmmTemplate(CppGemmTemplate):
|
|||
arg_defs, call_args, _, _ = kernel.args.python_argdefs()
|
||||
for i, buf in enumerate(call_args):
|
||||
if buf == self.b_index:
|
||||
arg_defs[i] = b_index
|
||||
call = f"{function_name}({', '.join(arg_defs)});"
|
||||
arg_defs[i] = ArgName(b_index)
|
||||
call = f"{function_name}({', '.join(x.full_name() for x in arg_defs)});"
|
||||
return call
|
||||
|
||||
assert placeholder not in kernel.render_hooks
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ log = logging.getLogger(__name__)
|
|||
|
||||
def get_kernel_argdefs(kernel):
|
||||
arg_defs, _, _, _ = kernel.args.python_argdefs()
|
||||
return arg_defs
|
||||
return [x.name for x in arg_defs]
|
||||
|
||||
|
||||
def _get_all_args(args_list, arg_types_list=None):
|
||||
|
|
|
|||
|
|
@ -1026,8 +1026,9 @@ class SIMDKernel(Kernel):
|
|||
for name in call_args
|
||||
]
|
||||
|
||||
argdef_names = [x.name for x in argdefs]
|
||||
msg = yellow_text(
|
||||
f" param names {argdefs}\n buf names {call_args}\n strides {stride_order_list}"
|
||||
f" param names {argdef_names}\n buf names {call_args}\n strides {stride_order_list}"
|
||||
+ f"\n sizes {size_list}\n sources {source_list}\n"
|
||||
)
|
||||
log.warning(msg)
|
||||
|
|
|
|||
|
|
@ -63,6 +63,7 @@ from ..virtualized import _ops as ops, OpsHandler, ReductionType, StoreMode, V
|
|||
from ..wrapper_benchmark import get_kernel_category_by_source_code
|
||||
from .block_analysis import BlockPatternMatcher
|
||||
from .common import (
|
||||
ArgName,
|
||||
BackendFeature,
|
||||
ConstexprArg,
|
||||
CSE,
|
||||
|
|
@ -3353,14 +3354,14 @@ class TritonKernel(SIMDKernel):
|
|||
isinstance(arg, WorkspaceArg)
|
||||
and arg.zero_mode == WorkspaceZeroMode.ZERO_ON_CALL
|
||||
):
|
||||
mutated_args.add(argname)
|
||||
mutated_args.add(argname.name)
|
||||
|
||||
mutated_args = sorted(mutated_args)
|
||||
|
||||
for tree in self.active_range_trees():
|
||||
sizearg = SizeArg(f"{tree.prefix}numel", tree.numel)
|
||||
signature.append(sizearg)
|
||||
argdefs.append(sizearg.name)
|
||||
argdefs.append(ArgName(sizearg.name))
|
||||
# constexpr version causes issues, see
|
||||
# https://github.com/pytorch/torchdynamo/pull/1362
|
||||
# triton_meta["constants"][len(argdefs)] = V.graph.sizevars.size_hint(
|
||||
|
|
@ -3372,7 +3373,7 @@ class TritonKernel(SIMDKernel):
|
|||
# new versions (but not old versions) of Triton need constexprs included in the signature
|
||||
if triton_version_uses_attrs_dict():
|
||||
signature.append(ConstexprArg(arg_name))
|
||||
argdefs.append(f"{arg_name} : tl.constexpr")
|
||||
argdefs.append(ArgName(arg_name, is_constexpr=True))
|
||||
|
||||
for tree in self.range_trees:
|
||||
if tree.is_reduction and self.persistent_reduction:
|
||||
|
|
@ -3428,7 +3429,7 @@ class TritonKernel(SIMDKernel):
|
|||
# https://github.com/pytorch/pytorch/issues/120478#issuecomment-1962822307
|
||||
# https://github.com/openai/triton/blob/231efe9ed2d200be0f69a07c298e4342b08efe3d/python/triton/runtime/jit.py#L384
|
||||
for arg_num in triton_meta["configs"][0].equal_to_1: # type: ignore[index]
|
||||
triton_meta["constants"][signature[arg_num].name] = 1 # type: ignore[index]
|
||||
triton_meta["constants"][signature[arg_num].name] = 1 # type: ignore[index,union-attr]
|
||||
|
||||
self.triton_meta = triton_meta
|
||||
|
||||
|
|
@ -3481,7 +3482,7 @@ class TritonKernel(SIMDKernel):
|
|||
"""
|
||||
code.splice(heuristics_line)
|
||||
code.writeline(
|
||||
f"def {name or str(Placeholder.KERNEL_NAME)}({', '.join(argdefs)}):"
|
||||
f"def {name or str(Placeholder.KERNEL_NAME)}({', '.join(x.full_name() for x in argdefs)}):"
|
||||
)
|
||||
with code.indent():
|
||||
self.codegen_static_numels(code)
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ from ..scheduler import BaseSchedulerNode
|
|||
from ..utils import Placeholder, triton_version_uses_attrs_dict
|
||||
from ..virtualized import V
|
||||
from .common import (
|
||||
ArgName,
|
||||
ConstexprArg,
|
||||
DeferredLine,
|
||||
IndentedBuffer,
|
||||
|
|
@ -649,7 +650,7 @@ class ComboKernel(Kernel):
|
|||
size_hints: dict[str, int],
|
||||
selected_kernel: TritonKernel,
|
||||
signature: list[Any],
|
||||
argdefs: list[str],
|
||||
argdefs: list[ArgName],
|
||||
pointwise_with_reduce: bool = False,
|
||||
) -> str:
|
||||
can_use_32bit = all(k.index_dtype == "tl.int32" for k in self.sub_kernels)
|
||||
|
|
@ -749,14 +750,16 @@ class ComboKernel(Kernel):
|
|||
|
||||
return [ConstexprArg(x) for x in block_names.keys()]
|
||||
|
||||
def add_numel_to_args(self, argdefs: list[str], signature: list[Any]) -> list[str]:
|
||||
def add_numel_to_args(
|
||||
self, argdefs: list[ArgName], signature: list[Any]
|
||||
) -> list[ArgName]:
|
||||
for num, sub_kernel in enumerate(self.sub_kernels):
|
||||
for tree in sub_kernel.active_range_trees():
|
||||
if not isinstance(tree.numel, (Integer, int)):
|
||||
# only if it is a dynamic shape
|
||||
sizearg = SizeArg(f"{tree.prefix}numel_{num}", tree.numel)
|
||||
signature.append(sizearg)
|
||||
argdefs.append(f"{tree.prefix}numel_{num}")
|
||||
argdefs.append(ArgName(f"{tree.prefix}numel_{num}"))
|
||||
self.dynamic_shape_args.append(f"{tree.prefix}numel_{num}")
|
||||
return argdefs
|
||||
|
||||
|
|
@ -834,7 +837,7 @@ class ComboKernel(Kernel):
|
|||
argdefs = self.add_numel_to_args(argdefs, signature)
|
||||
block_args = self.get_block_args()
|
||||
if self.enable_autotune:
|
||||
argdefs.extend([f"{x.name}: tl.constexpr" for x in block_args])
|
||||
argdefs.extend([ArgName(x.name, is_constexpr=True) for x in block_args])
|
||||
if triton_version_uses_attrs_dict():
|
||||
signature.extend(block_args)
|
||||
|
||||
|
|
@ -849,7 +852,7 @@ class ComboKernel(Kernel):
|
|||
)
|
||||
)
|
||||
code.writeline(
|
||||
f"def {name or str(Placeholder.KERNEL_NAME)}({', '.join(argdefs)}):"
|
||||
f"def {name or str(Placeholder.KERNEL_NAME)}({', '.join(x.full_name() for x in argdefs)}):"
|
||||
)
|
||||
|
||||
with code.indent():
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from ..runtime.hints import AttrsDescriptorWrapper
|
|||
from ..utils import _type_of, expr_fits_within_32bit, triton_version_uses_attrs_dict
|
||||
from ..virtualized import V
|
||||
from .common import (
|
||||
ArgName,
|
||||
ConstexprArg,
|
||||
KernelArgType,
|
||||
SizeArg,
|
||||
|
|
@ -104,13 +105,13 @@ def signature_to_meta(
|
|||
signature: list[KernelArgType],
|
||||
*,
|
||||
size_dtype: Optional[str],
|
||||
argdefs: list[str],
|
||||
argdefs: list[ArgName],
|
||||
indices: Optional[list[int]] = None,
|
||||
) -> dict[str, str]:
|
||||
if indices is None:
|
||||
indices = list(range(len(signature)))
|
||||
return {
|
||||
argdefs[i]: signature_of(arg, size_dtype=size_dtype)
|
||||
argdefs[i].name: signature_of(arg, size_dtype=size_dtype)
|
||||
for i, arg in zip(indices, signature)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -46,6 +46,7 @@ from ..utils import (
|
|||
)
|
||||
from ..virtualized import V
|
||||
from .common import (
|
||||
ArgName,
|
||||
CodeGen,
|
||||
DeferredLine,
|
||||
IndentedBuffer,
|
||||
|
|
@ -1657,7 +1658,7 @@ class PythonWrapperCodegen(CodeGen):
|
|||
signature,
|
||||
size_dtype=None, # try to infer based on symints
|
||||
indices=non_constant_indices,
|
||||
argdefs=kernel.arg_names,
|
||||
argdefs=[ArgName(x) for x in kernel.arg_names],
|
||||
)
|
||||
triton_meta: dict[str, Any] = {
|
||||
"signature": triton_signature,
|
||||
|
|
|
|||
|
|
@ -445,7 +445,7 @@ class TritonTemplateKernel(TritonKernel):
|
|||
def hook():
|
||||
# python_argdefs() cannot be run until after the rest of the template lazily adds more args
|
||||
arg_defs, *_ = self.args.python_argdefs()
|
||||
return f"{', '.join(arg_defs)}"
|
||||
return f"{', '.join(x.full_name() for x in arg_defs)}"
|
||||
|
||||
self.render_hooks["<ARGDEFS>"] = hook
|
||||
return "<ARGDEFS>"
|
||||
|
|
@ -515,7 +515,9 @@ class TritonTemplateKernel(TritonKernel):
|
|||
code = IndentedBuffer()
|
||||
code.splice(gen_common_triton_imports())
|
||||
code.splice(self.jit_lines())
|
||||
code.writeline(f"def {self.kernel_name}({', '.join(arg_defs)}):")
|
||||
code.writeline(
|
||||
f"def {self.kernel_name}({', '.join(x.full_name() for x in arg_defs)}):"
|
||||
)
|
||||
with code.indent():
|
||||
code.splice(self.defines)
|
||||
code.splice(renames.getvalue())
|
||||
|
|
|
|||
Loading…
Reference in a new issue