mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[inductor] Use index_dtype (int32/int64 depending on size) for argmax accumulators (#146651)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146651 Approved by: https://github.com/shunting314, https://github.com/eellison
This commit is contained in:
parent
80a1696679
commit
04ce02182b
2 changed files with 22 additions and 3 deletions
|
|
@ -255,6 +255,23 @@ class TestFixedConfigs(TestCase):
|
|||
)
|
||||
]
|
||||
self._check(fn, args, persistent=persistent, cfg=cfg)
|
||||
args = [
|
||||
torch.stack(
|
||||
[
|
||||
torch.tensor(
|
||||
[0.0] * 150 + [float("inf")] * 150,
|
||||
device="cuda",
|
||||
dtype=torch.float32,
|
||||
),
|
||||
torch.tensor(
|
||||
[0.0] * 150 + [-float("inf")] * 150,
|
||||
device="cuda",
|
||||
dtype=torch.float32,
|
||||
),
|
||||
]
|
||||
)
|
||||
]
|
||||
self._check(fn, args, persistent=persistent, cfg=cfg)
|
||||
|
||||
@parametrize("persistent", [False, True])
|
||||
@parametrize("rsplit", [32, 33])
|
||||
|
|
|
|||
|
|
@ -2545,9 +2545,10 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
|
|||
|
||||
if reduction_type in ("argmax", "argmin"):
|
||||
accumulator_index = f"_{result_var}_index"
|
||||
long_max = torch.iinfo(torch.int64).max
|
||||
index_dtype = self.features.select_index_dtype()
|
||||
self.body.writeline(
|
||||
f"{accumulator_index} = tl.full({self.dense_size_str()}, {long_max}, tl.int64)"
|
||||
f"{accumulator_index} = tl.full({self.dense_size_str()}, "
|
||||
f"{torch.iinfo(index_dtype).max}, {self.dtype_to_str(index_dtype)})"
|
||||
)
|
||||
root_op = {"argmax": "max", "argmin": "min"}[reduction_type]
|
||||
|
||||
|
|
@ -2609,8 +2610,9 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
|
|||
peer_val = self.codegen_cooperative_reduction_peer_combine(
|
||||
f"{result_var}_bval", src_dtype, default
|
||||
)
|
||||
index_dtype = self.features.select_index_dtype()
|
||||
peer_idx = self.codegen_cooperative_reduction_peer_combine(
|
||||
result_var, dtype, 0
|
||||
result_var, index_dtype, torch.iinfo(index_dtype).max
|
||||
)
|
||||
final_argreduce(self.post_loop_store, result_var, peer_val, peer_idx)
|
||||
elif is_welford_reduction(reduction_type):
|
||||
|
|
|
|||
Loading…
Reference in a new issue