[inductor] Use index_dtype (int32/int64 depending on size) for argmax accumulators (#146651)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146651
Approved by: https://github.com/shunting314, https://github.com/eellison
Jason Ansel 2025-02-06 14:56:59 -08:00 committed by PyTorch MergeBot
parent 80a1696679
commit 04ce02182b
2 changed files with 22 additions and 3 deletions
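
The core of the change: instead of always allocating the argmax/argmin index accumulator as tl.int64, the kernel now asks self.features.select_index_dtype() for the narrowest index dtype that fits, typically tl.int32, which reduces register pressure. A minimal standalone sketch of that selection rule (the helper below and its numels argument are hypothetical stand-ins, not the inductor API):

import torch

def select_index_dtype(numels):
    # Use 32-bit indices when every extent fits; otherwise fall back to int64.
    int32_max = torch.iinfo(torch.int32).max
    return torch.int32 if all(n <= int32_max for n in numels) else torch.int64

print(select_index_dtype([1024, 2048]))  # torch.int32
print(select_index_dtype([2**31, 16]))   # torch.int64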

@@ -255,6 +255,23 @@ class TestFixedConfigs(TestCase):
             )
         ]
         self._check(fn, args, persistent=persistent, cfg=cfg)
+        args = [
+            torch.stack(
+                [
+                    torch.tensor(
+                        [0.0] * 150 + [float("inf")] * 150,
+                        device="cuda",
+                        dtype=torch.float32,
+                    ),
+                    torch.tensor(
+                        [0.0] * 150 + [-float("inf")] * 150,
+                        device="cuda",
+                        dtype=torch.float32,
+                    ),
+                ]
+            )
+        ]
+        self._check(fn, args, persistent=persistent, cfg=cfg)

     @parametrize("persistent", [False, True])
     @parametrize("rsplit", [32, 33])

@@ -2545,9 +2545,10 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
         if reduction_type in ("argmax", "argmin"):
             accumulator_index = f"_{result_var}_index"
-            long_max = torch.iinfo(torch.int64).max
+            index_dtype = self.features.select_index_dtype()
             self.body.writeline(
-                f"{accumulator_index} = tl.full({self.dense_size_str()}, {long_max}, tl.int64)"
+                f"{accumulator_index} = tl.full({self.dense_size_str()}, "
+                f"{torch.iinfo(index_dtype).max}, {self.dtype_to_str(index_dtype)})"
             )
             root_op = {"argmax": "max", "argmin": "min"}[reduction_type]
@@ -2609,8 +2610,9 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
             peer_val = self.codegen_cooperative_reduction_peer_combine(
                 f"{result_var}_bval", src_dtype, default
             )
+            index_dtype = self.features.select_index_dtype()
             peer_idx = self.codegen_cooperative_reduction_peer_combine(
-                result_var, dtype, 0
+                result_var, index_dtype, torch.iinfo(index_dtype).max
             )
             final_argreduce(self.post_loop_store, result_var, peer_val, peer_idx)
         elif is_welford_reduction(reduction_type):
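
The cooperative-reduction path gets the matching fix: the peer index buffer is allocated in the selected index dtype, and its fill default moves from 0 to torch.iinfo(index_dtype).max. A plausible reading of why max is the right sentinel, sketched in plain Python (combine_argmax is illustrative, not inductor's Triton helper): argmax/argmin tie-breaking prefers the smaller index, so a peer slot that was never written must hold the largest representable index to lose every tie, whereas a default of 0 would win them.

def combine_argmax(pairs):
    # Reduce (value, index) pairs, breaking value ties toward the smaller index.
    best_val, best_idx = pairs[0]
    for val, idx in pairs[1:]:
        if val > best_val or (val == best_val and idx < best_idx):
            best_val, best_idx = val, idx
    return best_val, best_idx

SENTINEL = 2**31 - 1  # torch.iinfo(torch.int32).max
# An unwritten peer slot (default value, sentinel index) never beats a real
# result, even when the values tie:
print(combine_argmax([(float("-inf"), SENTINEL), (7.0, 42), (7.0, 10)]))  # (7.0, 10)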