mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[inductor] Fix torch.split bug on unbacked symint (#113406)
torch.split(x, l) fails when the elements of l are unbacked symints. E.g. l = y.tolist() makes l unbacked, because l depends on the data of y. The downstream call `SliceView.create()` evaluates the shape even when the input shape is an unbacked symint, which triggers the bug. Test Plan: python test/inductor/test_unbacked_symints.py -k test_split_with_sizes Pull Request resolved: https://github.com/pytorch/pytorch/pull/113406 Approved by: https://github.com/aakhundov, https://github.com/ezyang
This commit is contained in:
parent
4aa2c51a09
commit
74e10f0f60
5 changed files with 29 additions and 6 deletions
|
|
@ -54,6 +54,21 @@ class TestUnbackedSymints(TorchTestCase):
|
|||
|
||||
torch.testing.assert_close(actual, expected)
|
||||
|
||||
def test_split_with_sizes(self):
    """Compile torch.split with split sizes coming from .tolist().

    The sizes are data-dependent (unbacked symints), so this exercises
    the unbacked-symint path in inductor; results must match eager.
    """

    def fn(x, y):
        # y.tolist() yields values that depend on y's data, i.e.
        # unbacked symints under dynamo.
        sizes = y.tolist()
        pieces = torch.split(x, sizes)
        total = sizes[0] + sizes[1] + sizes[2]
        return pieces[0].sum(), total

    inputs = (torch.randn((32), device="cuda"), torch.tensor((7, 16, 9)))

    # capture_scalar_outputs is required for .tolist() to be traced
    # rather than causing a graph break.
    with dynamo_config.patch({"capture_scalar_outputs": True}):
        compiled = torch.compile(fn, fullgraph=True)
        actual = compiled(*inputs)
        expected = fn(*inputs)

    torch.testing.assert_close(actual, expected)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
|
|
|||
|
|
@ -381,6 +381,10 @@ class PythonPrinter(ExprPrinter):
|
|||
assert len(expr.args) >= 2
|
||||
return f"max({', '.join(map(self._print, expr.args))})"
|
||||
|
||||
def _print_Min(self, expr):
    # Render a sympy Min expression as a Python `min(...)` call.
    # A Min node always carries at least two operands.
    assert len(expr.args) >= 2
    printed = ", ".join(self._print(arg) for arg in expr.args)
    return f"min({printed})"
|
||||
|
||||
class OpOverrides:
|
||||
def __init__(self, parent):
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ import torch._logging
|
|||
from torch._prims_common import is_integer_dtype
|
||||
from torch.utils._sympy.functions import FloorDiv, ModularIndexing
|
||||
from torch.utils._sympy.value_ranges import ValueRanges
|
||||
|
||||
from ..._dynamo.utils import counters
|
||||
from .. import config, ir, scheduler
|
||||
from ..codecache import code_hash, get_path, PyCodeCache
|
||||
|
|
@ -1143,7 +1144,7 @@ class TritonKernel(Kernel):
|
|||
# indirect indexing
|
||||
cse_var = self.cse.varname_map[var.name]
|
||||
mask_vars.update(cse_var.mask_vars)
|
||||
elif var.name.startswith(("s", "ps")):
|
||||
elif var.name.startswith(("s", "ps", "i")):
|
||||
pass
|
||||
else:
|
||||
# var is one of xN, yN or rN
|
||||
|
|
|
|||
|
|
@ -2088,11 +2088,12 @@ class SliceView(View):
|
|||
start = cls.handle_negative_index(start, new_size[dim])
|
||||
end = cls.handle_negative_index(end, new_size[dim])
|
||||
|
||||
end = sizevars.evaluate_min(end, new_size[dim])
|
||||
start = sizevars.evaluate_min(start, end)
|
||||
if start == 0 and sizevars.size_hint(end - new_size[dim]) == 0 and step == 1:
|
||||
sizevars.guard_equals(end, new_size[dim])
|
||||
return x
|
||||
if free_unbacked_symbols(start) or free_unbacked_symbols(end):
|
||||
end = sympy.Min(end, new_size[dim])
|
||||
start = sympy.Min(start, end)
|
||||
else:
|
||||
end = sizevars.evaluate_min(end, new_size[dim])
|
||||
start = sizevars.evaluate_min(start, end)
|
||||
|
||||
new_size[dim] = FloorDiv(end - start + (step - 1), step)
|
||||
|
||||
|
|
|
|||
|
|
@ -363,6 +363,8 @@ try:
|
|||
"not_": z3.Not,
|
||||
"floor": self._ops.floor,
|
||||
"ceil": self._ops.ceil,
|
||||
"minimum": self._ops.min,
|
||||
"maximum": self._ops.max,
|
||||
}
|
||||
|
||||
if name in REPLACEMENT:
|
||||
|
|
|
|||
Loading…
Reference in a new issue