diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index bc89441e3bd..5e5cbf35baf 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -40,6 +40,7 @@ from .codegen.triton import ( ) from .codegen.triton_utils import config_of, signature_to_meta +from .codegen.wrapper import pexpr from .exc import CUDACompileError from .ir import ChoiceCaller, PrimitiveInfoType from .runtime.hints import DeviceProperties @@ -537,7 +538,7 @@ class TritonTemplateKernel(TritonKernel): meta = wrapper.add_meta_once(self.meta) grid_call = [ - texpr(V.graph.sizevars.simplify(s)) for s in self.call_sizes + pexpr(V.graph.sizevars.simplify(s)) for s in self.call_sizes ] + [meta] grid_call = f"{self.grid_fn.__module__}.{self.grid_fn.__name__}({', '.join(grid_call)})" wrapper.writeline(