diff --git a/test/dynamo/test_aot_autograd.py b/test/dynamo/test_aot_autograd.py
index c03bcb406ce..84b1e0a28c0 100644
--- a/test/dynamo/test_aot_autograd.py
+++ b/test/dynamo/test_aot_autograd.py
@@ -434,17 +434,15 @@ class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
                 super().__init__()
                 self.mean = torch.nn.Parameter(torch.randn(3, 3))
 
-            def forward(self, a, b, c, d, e, f):
+            def forward(self, a, b, e, f):
                 a.trunc_()
                 b.trunc_()
-                c.trunc_()
-                d.trunc_()
-                return (a + b + c + d + self.mean) * e * f
+                return (a + b + self.mean) * e * f
 
         a = torch.randn(3, 3, requires_grad=True)
         b = torch.randn(3, 3, requires_grad=True)
-        a1, a2, a3, a4 = a.clone(), a.clone(), a.clone(), a.clone()
-        b1, b2, b3, b4 = b.clone(), b.clone(), b.clone(), b.clone()
+        a1, a2 = a.clone(), a.clone()
+        b1, b2 = b.clone(), b.clone()
 
         failure_reason = None
 
@@ -457,8 +455,8 @@ class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
 
         cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
         f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
-        f(a1, a1, a1, a1, 2, 2)
-        f(a2, b2, b2, b2, 2, 2)
+        f(a1, a1, 2, 2)
+        f(a2, b2, 2, 2)
         self.assertEqual(cc.frame_count, 2)
         self.assertIn(
             """L['a'] is L['b']""",
@@ -475,10 +473,10 @@ class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
         d3, d4 = d.clone(), d.clone()
 
         f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
-        f(a3, b3, c3, c3, 3, 3)
-        f(a4, b4, c4, d4, 3, 3)
+        f(c3, c3, 3, 3)
+        f(c4, d4, 3, 3)
         self.assertEqual(cc.frame_count, 2)
-        self.assertIn("""L['c'] is L['d']""", failure_reason)
+        self.assertIn("""L['a'] is L['b']""", failure_reason)
 
     @patch("torch._functorch.config.debug_assert", True)
     def test_arg_dupe_via_dynamo_recompiles_many_with_global(self):
@@ -489,18 +487,16 @@ class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
                 super().__init__()
                 self.mean = torch.nn.Parameter(torch.randn(3, 3))
 
-            def forward(self, a, b, c, d, e, f):
+            def forward(self, a, b, e, f):
                 a.trunc_()
                 b.trunc_()
-                c.trunc_()
-                d.trunc_()
-                return (a + b + c + d + z + self.mean) * e * f
+                return (a + b + z + self.mean) * e * f
 
         a = torch.randn(3, 3, requires_grad=True)
         b = torch.randn(3, 3, requires_grad=True)
         z = a
-        a1, a2, a3, a4 = a.clone(), a.clone(), a.clone(), a.clone()
-        b1, b2, b3, b4 = b.clone(), b.clone(), b.clone(), b.clone()
+        a1, a2 = a.clone(), a.clone()
+        b1, b2 = b.clone(), b.clone()
 
         failure_reason = None
 
@@ -513,8 +509,8 @@ class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
 
         cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
         f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
-        f(a1, a1, a1, a1, 2, 2)
-        f(a2, b2, b2, b2, 2, 2)
+        f(a1, a1, 2, 2)
+        f(a2, b2, 2, 2)
         self.assertEqual(cc.frame_count, 2)
         self.assertIn(
             """L['a'] is L['b']""",
@@ -528,17 +524,15 @@ class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
                 super().__init__()
                 self.mean = torch.nn.Parameter(torch.randn(3, 3))
 
-            def forward(self, e, f, a, b, c, d):
+            def forward(self, e, f, a, b):
                 a.trunc_()
                 b.trunc_()
-                c.trunc_()
-                d.trunc_()
-                return (a + b + c + d + self.mean) * e[0] * f[0]
+                return (a + b + self.mean) * e[0] * f[0]
 
         a = torch.randn(3, 3, requires_grad=True)
         b = torch.randn(3, 3, requires_grad=True)
-        a1, a2, a3, a4 = a.clone(), a.clone(), a.clone(), a.clone()
-        b1, b2, b3, b4 = b.clone(), b.clone(), b.clone(), b.clone()
+        a1, a2 = a.clone(), a.clone()
+        b1, b2 = b.clone(), b.clone()
 
         failure_reason = None
 
@@ -551,8 +545,8 @@ class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
 
         cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
         f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
-        f([3, 2, 1], [4, 5, 6], a1, a1, a1, a1)
-        f([3, 2, 1], [4, 5, 6], a2, b2, b2, b2)
+        f([3, 2, 1], [4, 5, 6], a1, a1)
+        f([3, 2, 1], [4, 5, 6], a2, b2)
         self.assertEqual(cc.frame_count, 2)
         self.assertIn(
             """L['a'] is L['b']""",
@@ -569,8 +563,8 @@ class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
         d3, d4 = d.clone(), d.clone()
 
         f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
-        f([3, 2, 1], [4, 5, 6], a3, b3, c3, c3)
-        f([3, 2, 1], [4, 5, 6], a4, b4, c4, d4)
+        f([3, 2, 1], [4, 5, 6], c3, c3)
+        f([3, 2, 1], [4, 5, 6], c4, d4)
         self.assertEqual(cc.frame_count, 2)
 
     @patch("torch._functorch.config.debug_assert", True)
@@ -580,17 +574,15 @@ class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
                 super().__init__()
                 self.mean = torch.nn.Parameter(torch.randn(3, 3))
 
-            def forward(self, a, b, c, d):
+            def forward(self, a, b):
                 a.trunc_()
                 b.trunc_()
-                c.trunc_()
-                d.trunc_()
-                return a + b + c + d + self.mean
+                return a + b + self.mean
 
         a = torch.randn(3, 3, requires_grad=True)
         b = torch.randn(3, 3, requires_grad=True)
-        a1, a2, a3, a4 = a.clone(), a.clone(), a.clone(), a.clone()
-        b1, b2, b3, b4 = b.clone(), b.clone(), b.clone(), b.clone()
+        a1, a2 = a.clone(), a.clone()
+        b1, b2 = b.clone(), b.clone()
 
         failure_reason = None
 
@@ -603,8 +595,8 @@ class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
 
         cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
         f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
-        f(a1, a1, a1, a1)
-        f(a2, b2, b2, b2)
+        f(a1, a1)
+        f(a2, b2)
         self.assertEqual(cc.frame_count, 2)
         self.assertIn(
             """L['a'] is L['b']""",
@@ -621,10 +613,10 @@ class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
         d3, d4 = d.clone(), d.clone()
 
         f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
-        f(a3, b3, c3, c3)
-        f(a4, b4, c4, d4)
+        f(c3, c3)
+        f(c4, d4)
         self.assertEqual(cc.frame_count, 2)
-        self.assertIn("""L['c'] is L['d']""", failure_reason)
+        self.assertIn("""L['a'] is L['b']""", failure_reason)
 
     @patch("torch._functorch.config.debug_assert", True)
     def test_arg_dupe_via_dynamo_recompiles_many_args(self):
diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py
index 6d5c87f236c..2f9354f22f9 100644
--- a/torch/_dynamo/guards.py
+++ b/torch/_dynamo/guards.py
@@ -1744,6 +1744,7 @@ class CheckFunctionManager:
                 continue
 
             guard.create(builder)
+        self.check_fn = self.compile_check_fn(builder, guards, guard_fail_fn)
 
         # Keep track of weak references of objects with ID_MATCH guard. This
         # info is stored alongside optimized_code and check_fn and is used to
diff --git a/torch/_guards.py b/torch/_guards.py
index f0478659fad..09ed4a85b3c 100644
--- a/torch/_guards.py
+++ b/torch/_guards.py
@@ -172,7 +172,16 @@ class Guard:
         return self._hash
 
     def sort_key(self):
+        # Put the duplicate input guards at the end. The duplicate guards have
+        # two sources while guard.name only considers one source.
+        from ._dynamo.guards import GuardBuilder
+
+        is_duplicate_input = (
+            isinstance(self.create_fn, functools.partial)
+            and self.create_fn.func is GuardBuilder.DUPLICATE_INPUT
+        )
         return (
+            is_duplicate_input,
             self.source.value if self.source else -1,
             len(self.name),
             self.name,
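
Note on the torch/_guards.py change: guards are created in Guard.sort_key order, and a DUPLICATE_INPUT guard refers to two input sources while guard.name records only one of them, so sorting purely by source and name can place it before the guard that establishes its other source. Prefixing the key tuple with a boolean pushes these guards after everything else, since False sorts before True in Python tuple comparison. Below is a minimal, self-contained sketch of that ordering; FakeGuard and FakeGuardBuilder are hypothetical stand-ins for the real Dynamo classes, not PyTorch code.

import functools


class FakeGuardBuilder:
    # Stand-ins for GuardBuilder methods; bodies are irrelevant to sorting.
    @staticmethod
    def TENSOR_MATCH(guard):
        pass

    @staticmethod
    def DUPLICATE_INPUT(guard):
        pass


class FakeGuard:
    def __init__(self, name, create_fn):
        self.name = name
        self.create_fn = create_fn

    def sort_key(self):
        # Mirrors the patched Guard.sort_key: the boolean comes first in the
        # tuple, so duplicate-input guards sort after all other guards
        # regardless of their names.
        is_duplicate_input = (
            isinstance(self.create_fn, functools.partial)
            and self.create_fn.func is FakeGuardBuilder.DUPLICATE_INPUT
        )
        return (is_duplicate_input, len(self.name), self.name)


guards = [
    FakeGuard("L['a']", functools.partial(FakeGuardBuilder.DUPLICATE_INPUT)),
    FakeGuard("L['b']", FakeGuardBuilder.TENSOR_MATCH),
    FakeGuard("L['c']", FakeGuardBuilder.TENSOR_MATCH),
]
print([g.name for g in sorted(guards, key=FakeGuard.sort_key)])
# ["L['b']", "L['c']", "L['a']"] -- the DUPLICATE_INPUT guard comes last even
# though "L['a']" would sort first by name alone.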