[inductor] Fix torch.split bug on unbacked symint (#113406)

torch.split(x, l) fails when l's shape is the unbacked symint.

E.g. l =
y.tolist() makes l the unbacked shape, because l depends on the
data access of y. The downdtream call `SliceView.create()`
evaluates the shape even if the input shape is unbacked symint,
which brings up the bug.

Test Plan:
python test/inductor/test_unbacked_symints.py -k test_split_with_sizes

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113406
Approved by: https://github.com/aakhundov, https://github.com/ezyang
This commit is contained in:
chundian 2023-11-23 12:46:16 -08:00 committed by PyTorch MergeBot
parent 4aa2c51a09
commit 74e10f0f60
5 changed files with 29 additions and 6 deletions

View file

@ -54,6 +54,21 @@ class TestUnbackedSymints(TorchTestCase):
torch.testing.assert_close(actual, expected)
def test_split_with_sizes(self):
def fn(x, y):
l = y.tolist()
s = torch.split(x, l)
d = l[0] + l[1] + l[2]
return s[0].sum(), d
example_inputs = (torch.randn((32), device="cuda"), torch.tensor((7, 16, 9)))
with dynamo_config.patch({"capture_scalar_outputs": True}):
actual = torch.compile(fn, fullgraph=True)(*example_inputs)
expected = fn(*example_inputs)
torch.testing.assert_close(actual, expected)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View file

@ -381,6 +381,10 @@ class PythonPrinter(ExprPrinter):
assert len(expr.args) >= 2
return f"max({', '.join(map(self._print, expr.args))})"
def _print_Min(self, expr):
assert len(expr.args) >= 2
return f"min({', '.join(map(self._print, expr.args))})"
class OpOverrides:
def __init__(self, parent):

View file

@ -19,6 +19,7 @@ import torch._logging
from torch._prims_common import is_integer_dtype
from torch.utils._sympy.functions import FloorDiv, ModularIndexing
from torch.utils._sympy.value_ranges import ValueRanges
from ..._dynamo.utils import counters
from .. import config, ir, scheduler
from ..codecache import code_hash, get_path, PyCodeCache
@ -1143,7 +1144,7 @@ class TritonKernel(Kernel):
# indirect indexing
cse_var = self.cse.varname_map[var.name]
mask_vars.update(cse_var.mask_vars)
elif var.name.startswith(("s", "ps")):
elif var.name.startswith(("s", "ps", "i")):
pass
else:
# var is one of xN, yN or rN

View file

@ -2088,11 +2088,12 @@ class SliceView(View):
start = cls.handle_negative_index(start, new_size[dim])
end = cls.handle_negative_index(end, new_size[dim])
end = sizevars.evaluate_min(end, new_size[dim])
start = sizevars.evaluate_min(start, end)
if start == 0 and sizevars.size_hint(end - new_size[dim]) == 0 and step == 1:
sizevars.guard_equals(end, new_size[dim])
return x
if free_unbacked_symbols(start) or free_unbacked_symbols(end):
end = sympy.Min(end, new_size[dim])
start = sympy.Min(start, end)
else:
end = sizevars.evaluate_min(end, new_size[dim])
start = sizevars.evaluate_min(start, end)
new_size[dim] = FloorDiv(end - start + (step - 1), step)

View file

@ -363,6 +363,8 @@ try:
"not_": z3.Not,
"floor": self._ops.floor,
"ceil": self._ops.ceil,
"minimum": self._ops.min,
"maximum": self._ops.max,
}
if name in REPLACEMENT: