[dynamo][guards] Delay DUPLICATE_INPUT guard because of incorrect ordering (#123605)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123605
Approved by: https://github.com/jansel
ghstack dependencies: #123606
This commit is contained in:
Animesh Jain 2024-04-09 21:17:45 -07:00 committed by PyTorch MergeBot
parent 1dc4e1e335
commit 1346ebf12e
3 changed files with 42 additions and 40 deletions

View file

@@ -434,17 +434,15 @@ class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
super().__init__()
self.mean = torch.nn.Parameter(torch.randn(3, 3))
def forward(self, a, b, c, d, e, f):
def forward(self, a, b, e, f):
a.trunc_()
b.trunc_()
c.trunc_()
d.trunc_()
return (a + b + c + d + self.mean) * e * f
return (a + b + self.mean) * e * f
a = torch.randn(3, 3, requires_grad=True)
b = torch.randn(3, 3, requires_grad=True)
a1, a2, a3, a4 = a.clone(), a.clone(), a.clone(), a.clone()
b1, b2, b3, b4 = b.clone(), b.clone(), b.clone(), b.clone()
a1, a2 = a.clone(), a.clone()
b1, b2 = b.clone(), b.clone()
failure_reason = None
@@ -457,8 +455,8 @@ class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
f(a1, a1, a1, a1, 2, 2)
f(a2, b2, b2, b2, 2, 2)
f(a1, a1, 2, 2)
f(a2, b2, 2, 2)
self.assertEqual(cc.frame_count, 2)
self.assertIn(
"""L['a'] is L['b']""",
@@ -475,10 +473,10 @@ class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
d3, d4 = d.clone(), d.clone()
f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
f(a3, b3, c3, c3, 3, 3)
f(a4, b4, c4, d4, 3, 3)
f(c3, c3, 3, 3)
f(c4, d4, 3, 3)
self.assertEqual(cc.frame_count, 2)
self.assertIn("""L['c'] is L['d']""", failure_reason)
self.assertIn("""L['a'] is L['b']""", failure_reason)
@patch("torch._functorch.config.debug_assert", True)
def test_arg_dupe_via_dynamo_recompiles_many_with_global(self):
@@ -489,18 +487,16 @@ class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
super().__init__()
self.mean = torch.nn.Parameter(torch.randn(3, 3))
def forward(self, a, b, c, d, e, f):
def forward(self, a, b, e, f):
a.trunc_()
b.trunc_()
c.trunc_()
d.trunc_()
return (a + b + c + d + z + self.mean) * e * f
return (a + b + z + self.mean) * e * f
a = torch.randn(3, 3, requires_grad=True)
b = torch.randn(3, 3, requires_grad=True)
z = a
a1, a2, a3, a4 = a.clone(), a.clone(), a.clone(), a.clone()
b1, b2, b3, b4 = b.clone(), b.clone(), b.clone(), b.clone()
a1, a2 = a.clone(), a.clone()
b1, b2 = b.clone(), b.clone()
failure_reason = None
@@ -513,8 +509,8 @@ class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
f(a1, a1, a1, a1, 2, 2)
f(a2, b2, b2, b2, 2, 2)
f(a1, a1, 2, 2)
f(a2, b2, 2, 2)
self.assertEqual(cc.frame_count, 2)
self.assertIn(
"""L['a'] is L['b']""",
@@ -528,17 +524,15 @@ class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
super().__init__()
self.mean = torch.nn.Parameter(torch.randn(3, 3))
def forward(self, e, f, a, b, c, d):
def forward(self, e, f, a, b):
a.trunc_()
b.trunc_()
c.trunc_()
d.trunc_()
return (a + b + c + d + self.mean) * e[0] * f[0]
return (a + b + self.mean) * e[0] * f[0]
a = torch.randn(3, 3, requires_grad=True)
b = torch.randn(3, 3, requires_grad=True)
a1, a2, a3, a4 = a.clone(), a.clone(), a.clone(), a.clone()
b1, b2, b3, b4 = b.clone(), b.clone(), b.clone(), b.clone()
a1, a2 = a.clone(), a.clone()
b1, b2 = b.clone(), b.clone()
failure_reason = None
@@ -551,8 +545,8 @@ class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
f([3, 2, 1], [4, 5, 6], a1, a1, a1, a1)
f([3, 2, 1], [4, 5, 6], a2, b2, b2, b2)
f([3, 2, 1], [4, 5, 6], a1, a1)
f([3, 2, 1], [4, 5, 6], a2, b2)
self.assertEqual(cc.frame_count, 2)
self.assertIn(
"""L['a'] is L['b']""",
@@ -569,8 +563,8 @@ class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
d3, d4 = d.clone(), d.clone()
f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
f([3, 2, 1], [4, 5, 6], a3, b3, c3, c3)
f([3, 2, 1], [4, 5, 6], a4, b4, c4, d4)
f([3, 2, 1], [4, 5, 6], c3, c3)
f([3, 2, 1], [4, 5, 6], c4, d4)
self.assertEqual(cc.frame_count, 2)
@patch("torch._functorch.config.debug_assert", True)
@@ -580,17 +574,15 @@ class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
super().__init__()
self.mean = torch.nn.Parameter(torch.randn(3, 3))
def forward(self, a, b, c, d):
def forward(self, a, b):
a.trunc_()
b.trunc_()
c.trunc_()
d.trunc_()
return a + b + c + d + self.mean
return a + b + self.mean
a = torch.randn(3, 3, requires_grad=True)
b = torch.randn(3, 3, requires_grad=True)
a1, a2, a3, a4 = a.clone(), a.clone(), a.clone(), a.clone()
b1, b2, b3, b4 = b.clone(), b.clone(), b.clone(), b.clone()
a1, a2 = a.clone(), a.clone()
b1, b2 = b.clone(), b.clone()
failure_reason = None
@@ -603,8 +595,8 @@ class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
f(a1, a1, a1, a1)
f(a2, b2, b2, b2)
f(a1, a1)
f(a2, b2)
self.assertEqual(cc.frame_count, 2)
self.assertIn(
"""L['a'] is L['b']""",
@@ -621,10 +613,10 @@ class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
d3, d4 = d.clone(), d.clone()
f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
f(a3, b3, c3, c3)
f(a4, b4, c4, d4)
f(c3, c3)
f(c4, d4)
self.assertEqual(cc.frame_count, 2)
self.assertIn("""L['c'] is L['d']""", failure_reason)
self.assertIn("""L['a'] is L['b']""", failure_reason)
@patch("torch._functorch.config.debug_assert", True)
def test_arg_dupe_via_dynamo_recompiles_many_args(self):

View file

@@ -1744,6 +1744,7 @@ class CheckFunctionManager:
continue
guard.create(builder)
self.check_fn = self.compile_check_fn(builder, guards, guard_fail_fn)
# Keep track of weak references of objects with ID_MATCH guard. This
# info is stored alongside optimized_code and check_fn and is used to

View file

@@ -172,7 +172,16 @@ class Guard:
return self._hash
def sort_key(self):
# Put the duplicate input guards at the end. The duplicate guards have
# two sources while guard.name only considers one source.
from ._dynamo.guards import GuardBuilder
is_duplicate_input = (
isinstance(self.create_fn, functools.partial)
and self.create_fn.func is GuardBuilder.DUPLICATE_INPUT
)
return (
is_duplicate_input,
self.source.value if self.source else -1,
len(self.name),
self.name,