From 04ce02182b6a4e90f80a906ef303f60af492a642 Mon Sep 17 00:00:00 2001
From: Jason Ansel
Date: Thu, 6 Feb 2025 14:56:59 -0800
Subject: [PATCH] [inductor] Use index_dtype (int32/int64 depending on size)
 for argmax accumulators (#146651)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146651
Approved by: https://github.com/shunting314, https://github.com/eellison
---
 test/inductor/test_cooperative_reductions.py | 17 +++++++++++++++++
 torch/_inductor/codegen/triton.py            |  8 +++++---
 2 files changed, 22 insertions(+), 3 deletions(-)

diff --git a/test/inductor/test_cooperative_reductions.py b/test/inductor/test_cooperative_reductions.py
index 9cc9c830e01..f0d05fb8db8 100644
--- a/test/inductor/test_cooperative_reductions.py
+++ b/test/inductor/test_cooperative_reductions.py
@@ -255,6 +255,23 @@ class TestFixedConfigs(TestCase):
             )
         ]
         self._check(fn, args, persistent=persistent, cfg=cfg)
+        args = [
+            torch.stack(
+                [
+                    torch.tensor(
+                        [0.0] * 150 + [float("inf")] * 150,
+                        device="cuda",
+                        dtype=torch.float32,
+                    ),
+                    torch.tensor(
+                        [0.0] * 150 + [-float("inf")] * 150,
+                        device="cuda",
+                        dtype=torch.float32,
+                    ),
+                ]
+            )
+        ]
+        self._check(fn, args, persistent=persistent, cfg=cfg)
 
     @parametrize("persistent", [False, True])
     @parametrize("rsplit", [32, 33])
diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
index b66073b4d87..c0898d13c26 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -2545,9 +2545,10 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
 
         if reduction_type in ("argmax", "argmin"):
             accumulator_index = f"_{result_var}_index"
-            long_max = torch.iinfo(torch.int64).max
+            index_dtype = self.features.select_index_dtype()
             self.body.writeline(
-                f"{accumulator_index} = tl.full({self.dense_size_str()}, {long_max}, tl.int64)"
+                f"{accumulator_index} = tl.full({self.dense_size_str()}, "
+                f"{torch.iinfo(index_dtype).max}, {self.dtype_to_str(index_dtype)})"
             )
             root_op = {"argmax": "max", "argmin": "min"}[reduction_type]
 
@@ -2609,8 +2610,9 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
                 peer_val = self.codegen_cooperative_reduction_peer_combine(
                     f"{result_var}_bval", src_dtype, default
                 )
+                index_dtype = self.features.select_index_dtype()
                 peer_idx = self.codegen_cooperative_reduction_peer_combine(
-                    result_var, dtype, 0
+                    result_var, index_dtype, torch.iinfo(index_dtype).max
                 )
                 final_argreduce(self.post_loop_store, result_var, peer_val, peer_idx)
             elif is_welford_reduction(reduction_type):