diff --git a/test/inductor/test_unbacked_symints.py b/test/inductor/test_unbacked_symints.py index 2797bea8ceb..4ab72c6721b 100644 --- a/test/inductor/test_unbacked_symints.py +++ b/test/inductor/test_unbacked_symints.py @@ -54,6 +54,21 @@ class TestUnbackedSymints(TorchTestCase): torch.testing.assert_close(actual, expected) + def test_split_with_sizes(self): + def fn(x, y): + l = y.tolist() + s = torch.split(x, l) + d = l[0] + l[1] + l[2] + return s[0].sum(), d + + example_inputs = (torch.randn((32), device="cuda"), torch.tensor((7, 16, 9))) + + with dynamo_config.patch({"capture_scalar_outputs": True}): + actual = torch.compile(fn, fullgraph=True)(*example_inputs) + expected = fn(*example_inputs) + + torch.testing.assert_close(actual, expected) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index 0b6fa9732bc..8eba8feaee1 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -381,6 +381,10 @@ class PythonPrinter(ExprPrinter): assert len(expr.args) >= 2 return f"max({', '.join(map(self._print, expr.args))})" + def _print_Min(self, expr): + assert len(expr.args) >= 2 + return f"min({', '.join(map(self._print, expr.args))})" + class OpOverrides: def __init__(self, parent): diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 2e970742084..1c21ff5fe95 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -19,6 +19,7 @@ import torch._logging from torch._prims_common import is_integer_dtype from torch.utils._sympy.functions import FloorDiv, ModularIndexing from torch.utils._sympy.value_ranges import ValueRanges + from ..._dynamo.utils import counters from .. import config, ir, scheduler from ..codecache import code_hash, get_path, PyCodeCache @@ -1143,7 +1144,7 @@ class TritonKernel(Kernel): # indirect indexing cse_var = self.cse.varname_map[var.name] mask_vars.update(cse_var.mask_vars) - elif var.name.startswith(("s", "ps")): + elif var.name.startswith(("s", "ps", "i")): pass else: # var is one of xN, yN or rN diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index bf9c9fff123..b3a22f58116 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -2088,11 +2088,12 @@ class SliceView(View): start = cls.handle_negative_index(start, new_size[dim]) end = cls.handle_negative_index(end, new_size[dim]) - end = sizevars.evaluate_min(end, new_size[dim]) - start = sizevars.evaluate_min(start, end) - if start == 0 and sizevars.size_hint(end - new_size[dim]) == 0 and step == 1: - sizevars.guard_equals(end, new_size[dim]) - return x + if free_unbacked_symbols(start) or free_unbacked_symbols(end): + end = sympy.Min(end, new_size[dim]) + start = sympy.Min(start, end) + else: + end = sizevars.evaluate_min(end, new_size[dim]) + start = sizevars.evaluate_min(start, end) new_size[dim] = FloorDiv(end - start + (step - 1), step) diff --git a/torch/fx/experimental/validator.py b/torch/fx/experimental/validator.py index 8c795bb5b3d..48ad07dd855 100644 --- a/torch/fx/experimental/validator.py +++ b/torch/fx/experimental/validator.py @@ -363,6 +363,8 @@ try: "not_": z3.Not, "floor": self._ops.floor, "ceil": self._ops.ceil, + "minimum": self._ops.min, + "maximum": self._ops.max, } if name in REPLACEMENT: