Use pexpr, not texpr in Triton launch codegen (#128038)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128038
Approved by: https://github.com/Skylion007
This commit is contained in:
Edward Z. Yang 2024-06-05 13:06:01 -04:00 committed by PyTorch MergeBot
parent 8bcebc8dae
commit d3ad84c38f

View file

@ -40,6 +40,7 @@ from .codegen.triton import (
)
from .codegen.triton_utils import config_of, signature_to_meta
from .codegen.wrapper import pexpr
from .exc import CUDACompileError
from .ir import ChoiceCaller, PrimitiveInfoType
from .runtime.hints import DeviceProperties
@ -537,7 +538,7 @@ class TritonTemplateKernel(TritonKernel):
meta = wrapper.add_meta_once(self.meta)
grid_call = [
texpr(V.graph.sizevars.simplify(s)) for s in self.call_sizes
pexpr(V.graph.sizevars.simplify(s)) for s in self.call_sizes
] + [meta]
grid_call = f"{self.grid_fn.__module__}.{self.grid_fn.__name__}({', '.join(grid_call)})"
wrapper.writeline(