From 6f07847efe94157b9ab0bdba080f952ecebfbdc3 Mon Sep 17 00:00:00 2001
From: bobrenjc93
Date: Thu, 23 Jan 2025 18:07:12 +0000
Subject: [PATCH] Bail on checking internal overlap when dealing with unbacked
 symints (#145385)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145385
Approved by: https://github.com/ezyang
---
 aten/src/ATen/MemoryOverlap.cpp | 14 ++++++++++++--
 test/dynamo/test_misc.py        | 20 ++++++++++++++++++++
 test/test_dynamic_shapes.py     |  4 ++--
 torch/_inductor/lowering.py     |  2 +-
 torch/_inductor/sizevars.py     | 10 ++++++++--
 5 files changed, 43 insertions(+), 7 deletions(-)

diff --git a/aten/src/ATen/MemoryOverlap.cpp b/aten/src/ATen/MemoryOverlap.cpp
index 0ed36ebfc8d..61336037d71 100644
--- a/aten/src/ATen/MemoryOverlap.cpp
+++ b/aten/src/ATen/MemoryOverlap.cpp
@@ -12,12 +12,22 @@ MemOverlap has_internal_overlap(const TensorBase& tensor) {
 MemOverlap has_internal_overlap(TensorImpl* t) {
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(t->layout() == kStrided);
 
+  auto sizes = t->sym_sizes();
+  auto strides = t->sym_strides();
+
+  // When we have unbacked symint strides, is_non_overlapping_and_dense
+  // often results in guard on data dependent errors. For now
+  // let us bail early if there are unbacked symint strides.
+  for (const auto i : c10::irange(strides.size())) {
+    if (!strides[i].has_hint()) {
+      return MemOverlap::TooHard;
+    }
+  }
+
   if (t->is_non_overlapping_and_dense()) {
     return MemOverlap::No;
   }
 
-  auto strides = t->sym_strides();
-  auto sizes = t->sym_sizes();
   for (const auto i : c10::irange(strides.size())) {
     // NB: The size oblivious test is written very carefully here. When
     // unbacked SymInts are involved, we should try to conservatively report
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index 73264cf29c8..5dd50bcc6f2 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -7614,6 +7614,26 @@ utils_device.CURRENT_DEVICE == None""".split(
         opt = torch.compile(fn, fullgraph=True)
         opt(*inputs)
 
+    @torch._dynamo.config.patch(capture_scalar_outputs=True)
+    @torch._dynamo.config.patch(assume_static_by_default=True)
+    def test_symint_copy_into_unbacked_slice(self):
+        @torch.compile()
+        def fn(a, x):
+            u0 = torch.tensor(x[0].to(torch.int64).item()).item()
+            B, H, T, D = a.shape
+            a_padding = torch.zeros((B, H, u0, D), dtype=torch.float64)
+            b = torch.cat([a, a_padding], dim=2)
+            c = torch.randn(B, H, 152, D)
+            b[:, :, :152, :] = c
+            return b
+
+        x = torch.tensor([0])
+        torch._dynamo.decorators.mark_unbacked(x, 0)
+        a = torch.zeros((1, 16, 152, 96))
+
+        # Previously would crash with guard on data dependent error
+        fn(a, x)
+
     @torch._dynamo.config.patch(capture_scalar_outputs=True)
     def test_symint_fold_nontrivial_product_modulo(self):
         @torch.compile(fullgraph=True)
diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py
index 09c705ffa37..55191a9c927 100644
--- a/test/test_dynamic_shapes.py
+++ b/test/test_dynamic_shapes.py
@@ -1066,7 +1066,7 @@ def forward(self, x_1):
         self.assertEqual(cf(torch.empty_strided((u0, 2), (2, 1), device="meta")), 0)
         self.assertEqual(cf(torch.empty_strided((2, u0), (1, 2), device="meta")), 0)
         self.assertEqual(cf(torch.empty_strided((u0,), (1,), device="meta")), 0)
-        self.assertEqual(cf(torch.empty_strided((1,), (u0,), device="meta")), 0)
+        self.assertEqual(cf(torch.empty_strided((1,), (u0,), device="meta")), 2)
         Max = torch.sym_max
         self.assertEqual(
             cf(
@@ -1076,7 +1076,7 @@ def forward(self, x_1):
                     device="meta",
                 )
             ),
-            0,
+            2,
         )
 
         # Wobbling these to zero is OK too
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index 9fc50cc7db9..2783180e915 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -1009,7 +1009,7 @@ def squeeze(x, dim=None):
     for d, s in enumerate(x.get_size()):
         if not (
             d in dims
-            and V.graph.sizevars.evaluate_expr(sympy.Eq(s, 1, size_oblivious=True))
+            and V.graph.sizevars.evaluate_expr(sympy.Eq(s, 1), size_oblivious=True)
         ):
             new_shape.append(s)
 
diff --git a/torch/_inductor/sizevars.py b/torch/_inductor/sizevars.py
index cdb94455939..532073a377f 100644
--- a/torch/_inductor/sizevars.py
+++ b/torch/_inductor/sizevars.py
@@ -455,9 +455,15 @@ class SizeVarAllocator:
     # as this will ensure that you actually have a sympy'ified expression,
     # and will prevent you from incorrectly writing evaluate_expr(a == b)
     # which does the wrong thing if a or b is a sympy expression
-    def evaluate_expr(self, left: Union[Expr, sympy.logic.boolalg.Boolean]) -> bool:
+    def evaluate_expr(
+        self,
+        left: Union[Expr, sympy.logic.boolalg.Boolean],
+        size_oblivious: bool = False,
+    ) -> bool:
         assert isinstance(left, (Expr, sympy.logic.boolalg.Boolean)), type(left)
-        return self.shape_env.evaluate_expr(sympy.sympify(left))
+        return self.shape_env.evaluate_expr(
+            sympy.sympify(left), size_oblivious=size_oblivious
+        )
 
     def evaluate_min(self, left: Expr, right: Expr) -> Expr:
         """return the smaller of left and right, and guard on that choice"""
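
For reference, the new test can be run as a standalone repro of the failure
mode this patch addresses. Below is a minimal sketch, assuming a PyTorch build
that includes this patch; the two config assignments stand in for the test's
decorators, and the comments are annotations, not part of the patch.

import torch

# Mirror the decorators on test_symint_copy_into_unbacked_slice.
torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.assume_static_by_default = True


@torch.compile()
def fn(a, x):
    # With capture_scalar_outputs, .item() produces an unbacked SymInt u0
    # whose value is unknown at trace time.
    u0 = torch.tensor(x[0].to(torch.int64).item()).item()
    B, H, T, D = a.shape
    a_padding = torch.zeros((B, H, u0, D), dtype=torch.float64)
    # b has shape (B, H, 152 + u0, D), so its strides contain u0.
    b = torch.cat([a, a_padding], dim=2)
    c = torch.randn(B, H, 152, D)
    # The slice assignment dispatches to copy_, which runs
    # has_internal_overlap on a destination whose strides are unbacked.
    # Before this patch that raised a guard-on-data-dependent error;
    # now it bails with MemOverlap::TooHard and compilation proceeds.
    b[:, :, :152, :] = c
    return b


x = torch.tensor([0])
torch._dynamo.decorators.mark_unbacked(x, 0)
a = torch.zeros((1, 16, 152, 96))

print(fn(a, x).shape)  # torch.Size([1, 16, 152, 96]), since u0 == 0 here

The _inductor side of the change is mechanical: SizeVarAllocator.evaluate_expr
gains a size_oblivious parameter that it forwards to ShapeEnv.evaluate_expr,
and the squeeze lowering now passes size_oblivious=True to evaluate_expr
rather than to sympy.Eq, where the earlier code had misplaced it.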