mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[dynamo][guards] Delay DUPLICATE_INPUT guard because of incorrect ordering (#123605)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123605 Approved by: https://github.com/jansel ghstack dependencies: #123606
This commit is contained in:
parent
1dc4e1e335
commit
1346ebf12e
3 changed files with 42 additions and 40 deletions
|
|
@ -434,17 +434,15 @@ class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
|
|||
super().__init__()
|
||||
self.mean = torch.nn.Parameter(torch.randn(3, 3))
|
||||
|
||||
def forward(self, a, b, c, d, e, f):
|
||||
def forward(self, a, b, e, f):
|
||||
a.trunc_()
|
||||
b.trunc_()
|
||||
c.trunc_()
|
||||
d.trunc_()
|
||||
return (a + b + c + d + self.mean) * e * f
|
||||
return (a + b + self.mean) * e * f
|
||||
|
||||
a = torch.randn(3, 3, requires_grad=True)
|
||||
b = torch.randn(3, 3, requires_grad=True)
|
||||
a1, a2, a3, a4 = a.clone(), a.clone(), a.clone(), a.clone()
|
||||
b1, b2, b3, b4 = b.clone(), b.clone(), b.clone(), b.clone()
|
||||
a1, a2 = a.clone(), a.clone()
|
||||
b1, b2 = b.clone(), b.clone()
|
||||
|
||||
failure_reason = None
|
||||
|
||||
|
|
@ -457,8 +455,8 @@ class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
|
|||
cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
|
||||
|
||||
f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
|
||||
f(a1, a1, a1, a1, 2, 2)
|
||||
f(a2, b2, b2, b2, 2, 2)
|
||||
f(a1, a1, 2, 2)
|
||||
f(a2, b2, 2, 2)
|
||||
self.assertEqual(cc.frame_count, 2)
|
||||
self.assertIn(
|
||||
"""L['a'] is L['b']""",
|
||||
|
|
@ -475,10 +473,10 @@ class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
|
|||
d3, d4 = d.clone(), d.clone()
|
||||
|
||||
f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
|
||||
f(a3, b3, c3, c3, 3, 3)
|
||||
f(a4, b4, c4, d4, 3, 3)
|
||||
f(c3, c3, 3, 3)
|
||||
f(c4, d4, 3, 3)
|
||||
self.assertEqual(cc.frame_count, 2)
|
||||
self.assertIn("""L['c'] is L['d']""", failure_reason)
|
||||
self.assertIn("""L['a'] is L['b']""", failure_reason)
|
||||
|
||||
@patch("torch._functorch.config.debug_assert", True)
|
||||
def test_arg_dupe_via_dynamo_recompiles_many_with_global(self):
|
||||
|
|
@ -489,18 +487,16 @@ class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
|
|||
super().__init__()
|
||||
self.mean = torch.nn.Parameter(torch.randn(3, 3))
|
||||
|
||||
def forward(self, a, b, c, d, e, f):
|
||||
def forward(self, a, b, e, f):
|
||||
a.trunc_()
|
||||
b.trunc_()
|
||||
c.trunc_()
|
||||
d.trunc_()
|
||||
return (a + b + c + d + z + self.mean) * e * f
|
||||
return (a + b + z + self.mean) * e * f
|
||||
|
||||
a = torch.randn(3, 3, requires_grad=True)
|
||||
b = torch.randn(3, 3, requires_grad=True)
|
||||
z = a
|
||||
a1, a2, a3, a4 = a.clone(), a.clone(), a.clone(), a.clone()
|
||||
b1, b2, b3, b4 = b.clone(), b.clone(), b.clone(), b.clone()
|
||||
a1, a2 = a.clone(), a.clone()
|
||||
b1, b2 = b.clone(), b.clone()
|
||||
|
||||
failure_reason = None
|
||||
|
||||
|
|
@ -513,8 +509,8 @@ class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
|
|||
cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
|
||||
|
||||
f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
|
||||
f(a1, a1, a1, a1, 2, 2)
|
||||
f(a2, b2, b2, b2, 2, 2)
|
||||
f(a1, a1, 2, 2)
|
||||
f(a2, b2, 2, 2)
|
||||
self.assertEqual(cc.frame_count, 2)
|
||||
self.assertIn(
|
||||
"""L['a'] is L['b']""",
|
||||
|
|
@ -528,17 +524,15 @@ class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
|
|||
super().__init__()
|
||||
self.mean = torch.nn.Parameter(torch.randn(3, 3))
|
||||
|
||||
def forward(self, e, f, a, b, c, d):
|
||||
def forward(self, e, f, a, b):
|
||||
a.trunc_()
|
||||
b.trunc_()
|
||||
c.trunc_()
|
||||
d.trunc_()
|
||||
return (a + b + c + d + self.mean) * e[0] * f[0]
|
||||
return (a + b + self.mean) * e[0] * f[0]
|
||||
|
||||
a = torch.randn(3, 3, requires_grad=True)
|
||||
b = torch.randn(3, 3, requires_grad=True)
|
||||
a1, a2, a3, a4 = a.clone(), a.clone(), a.clone(), a.clone()
|
||||
b1, b2, b3, b4 = b.clone(), b.clone(), b.clone(), b.clone()
|
||||
a1, a2 = a.clone(), a.clone()
|
||||
b1, b2 = b.clone(), b.clone()
|
||||
|
||||
failure_reason = None
|
||||
|
||||
|
|
@ -551,8 +545,8 @@ class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
|
|||
cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
|
||||
|
||||
f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
|
||||
f([3, 2, 1], [4, 5, 6], a1, a1, a1, a1)
|
||||
f([3, 2, 1], [4, 5, 6], a2, b2, b2, b2)
|
||||
f([3, 2, 1], [4, 5, 6], a1, a1)
|
||||
f([3, 2, 1], [4, 5, 6], a2, b2)
|
||||
self.assertEqual(cc.frame_count, 2)
|
||||
self.assertIn(
|
||||
"""L['a'] is L['b']""",
|
||||
|
|
@ -569,8 +563,8 @@ class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
|
|||
d3, d4 = d.clone(), d.clone()
|
||||
|
||||
f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
|
||||
f([3, 2, 1], [4, 5, 6], a3, b3, c3, c3)
|
||||
f([3, 2, 1], [4, 5, 6], a4, b4, c4, d4)
|
||||
f([3, 2, 1], [4, 5, 6], c3, c3)
|
||||
f([3, 2, 1], [4, 5, 6], c4, d4)
|
||||
self.assertEqual(cc.frame_count, 2)
|
||||
|
||||
@patch("torch._functorch.config.debug_assert", True)
|
||||
|
|
@ -580,17 +574,15 @@ class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
|
|||
super().__init__()
|
||||
self.mean = torch.nn.Parameter(torch.randn(3, 3))
|
||||
|
||||
def forward(self, a, b, c, d):
|
||||
def forward(self, a, b):
|
||||
a.trunc_()
|
||||
b.trunc_()
|
||||
c.trunc_()
|
||||
d.trunc_()
|
||||
return a + b + c + d + self.mean
|
||||
return a + b + self.mean
|
||||
|
||||
a = torch.randn(3, 3, requires_grad=True)
|
||||
b = torch.randn(3, 3, requires_grad=True)
|
||||
a1, a2, a3, a4 = a.clone(), a.clone(), a.clone(), a.clone()
|
||||
b1, b2, b3, b4 = b.clone(), b.clone(), b.clone(), b.clone()
|
||||
a1, a2 = a.clone(), a.clone()
|
||||
b1, b2 = b.clone(), b.clone()
|
||||
|
||||
failure_reason = None
|
||||
|
||||
|
|
@ -603,8 +595,8 @@ class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
|
|||
cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
|
||||
|
||||
f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
|
||||
f(a1, a1, a1, a1)
|
||||
f(a2, b2, b2, b2)
|
||||
f(a1, a1)
|
||||
f(a2, b2)
|
||||
self.assertEqual(cc.frame_count, 2)
|
||||
self.assertIn(
|
||||
"""L['a'] is L['b']""",
|
||||
|
|
@ -621,10 +613,10 @@ class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
|
|||
d3, d4 = d.clone(), d.clone()
|
||||
|
||||
f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
|
||||
f(a3, b3, c3, c3)
|
||||
f(a4, b4, c4, d4)
|
||||
f(c3, c3)
|
||||
f(c4, d4)
|
||||
self.assertEqual(cc.frame_count, 2)
|
||||
self.assertIn("""L['c'] is L['d']""", failure_reason)
|
||||
self.assertIn("""L['a'] is L['b']""", failure_reason)
|
||||
|
||||
@patch("torch._functorch.config.debug_assert", True)
|
||||
def test_arg_dupe_via_dynamo_recompiles_many_args(self):
|
||||
|
|
|
|||
|
|
@ -1744,6 +1744,7 @@ class CheckFunctionManager:
|
|||
continue
|
||||
|
||||
guard.create(builder)
|
||||
|
||||
self.check_fn = self.compile_check_fn(builder, guards, guard_fail_fn)
|
||||
# Keep track of weak references of objects with ID_MATCH guard. This
|
||||
# info is stored alongside optimized_code and check_fn and is used to
|
||||
|
|
|
|||
|
|
@ -172,7 +172,16 @@ class Guard:
|
|||
return self._hash
|
||||
|
||||
def sort_key(self):
|
||||
# Put the duplicate input guards at the end. The duplicate guards have
|
||||
# two sources while guard.name only considers one source.
|
||||
from ._dynamo.guards import GuardBuilder
|
||||
|
||||
is_duplicate_input = (
|
||||
isinstance(self.create_fn, functools.partial)
|
||||
and self.create_fn.func is GuardBuilder.DUPLICATE_INPUT
|
||||
)
|
||||
return (
|
||||
is_duplicate_input,
|
||||
self.source.value if self.source else -1,
|
||||
len(self.name),
|
||||
self.name,
|
||||
|
|
|
|||
Loading…
Reference in a new issue