From d3ad84c38f5b06bd0278bec6dc6a2d42e12fa97e Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Wed, 5 Jun 2024 13:06:01 -0400 Subject: [PATCH] Use pexpr, not texpr in Triton launch codegen (#128038) Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/128038 Approved by: https://github.com/Skylion007 --- torch/_inductor/select_algorithm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index bc89441e3bd..5e5cbf35baf 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -40,6 +40,7 @@ from .codegen.triton import ( ) from .codegen.triton_utils import config_of, signature_to_meta +from .codegen.wrapper import pexpr from .exc import CUDACompileError from .ir import ChoiceCaller, PrimitiveInfoType from .runtime.hints import DeviceProperties @@ -537,7 +538,7 @@ class TritonTemplateKernel(TritonKernel): meta = wrapper.add_meta_once(self.meta) grid_call = [ - texpr(V.graph.sizevars.simplify(s)) for s in self.call_sizes + pexpr(V.graph.sizevars.simplify(s)) for s in self.call_sizes ] + [meta] grid_call = f"{self.grid_fn.__module__}.{self.grid_fn.__name__}({', '.join(grid_call)})" wrapper.writeline(