From 618cc82e770cf758bc9cbcdfd4c61e2d73bdca25 Mon Sep 17 00:00:00 2001 From: Richard Zou Date: Fri, 23 Jun 2023 12:56:17 -0700 Subject: [PATCH] Stop Dynamo from peeking into wrap's body (#104076) When Dynamo sees `wrap(f, x)`, and it decides that `f` is unsafe, Dynamo should fall back to eager mode and stop introspection all the way throughout the call of `f`. The motivation is: - it's easier to test `wrap` this way (it is clearer how many graph breaks should occur) - Other HigherOrderOperator do this because their execution of the body involves code that is not necessarily Dynamo-able. e.g. functorch transforms. Since `wrap` is a test for the HigherOrderOp mechanism, it should reflect what other HigherOrderOps do. Pull Request resolved: https://github.com/pytorch/pytorch/pull/104076 Approved by: https://github.com/ydwu4 --- test/dynamo/test_higher_order_ops.py | 16 +++++----------- torch/_higher_order_ops/wrap.py | 13 +++++++++++-- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py index 24230629fe1..6ff679cc22e 100644 --- a/test/dynamo/test_higher_order_ops.py +++ b/test/dynamo/test_higher_order_ops.py @@ -607,10 +607,7 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor): result = f(x) self.assertEqual(result, inner(x)) - # It's unclear if this is correct: dynamo graph breaks on wrap but - # then interposes on wrap.__call__, which invokes fn(*args), - # leading to two graphs being compiled - self.assertEqual(cnt.frame_count, 2) + self.assertEqual(cnt.frame_count, 0) def test_fallback_on_graph_break_complicated(self): cnt = CompileCounter() @@ -631,10 +628,7 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor): result = f(x) self.assertEqual(result, inner(x)) - # It's unclear if this is correct: dynamo graph breaks on wrap but - # then interposes on wrap.__call__, which invokes fn(*args), - # leading to four graphs being compiled: clone, sin, sin, clone - self.assertEqual(cnt.frame_count, 4) + self.assertEqual(cnt.frame_count, 2) def test_modules(self): counters.clear() @@ -688,7 +682,7 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor): x = torch.randn(3) result = f(x) self.assertEqual(result, [1, torch.sin(x), 2.0]) - self.assertEqual(cnt.frame_count, 1) + self.assertEqual(cnt.frame_count, 0) self.assertEqual( dict(counters["graph_break"]), {"HigherOrderOperator body's output must consist of tensors only": 1}, @@ -708,7 +702,7 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor): result = f(x) self.assertEqual(result, ((x.sin(), x.cos()),)) - self.assertEqual(cnt.frame_count, 1) + self.assertEqual(cnt.frame_count, 0) self.assertEqual( dict(counters["graph_break"]), {"HigherOrderOperator body's output must consist of tensors only": 1}, @@ -727,7 +721,7 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor): x = torch.randn(3) result = f(x) self.assertEqual(result, [{"a": -x}]) - self.assertEqual(cnt.frame_count, 1) + self.assertEqual(cnt.frame_count, 0) self.assertEqual( dict(counters["graph_break"]), {"HigherOrderOperator body's output must consist of tensors only": 1}, diff --git a/torch/_higher_order_ops/wrap.py b/torch/_higher_order_ops/wrap.py index 6ef219d5263..36eb3b3f571 100644 --- a/torch/_higher_order_ops/wrap.py +++ b/torch/_higher_order_ops/wrap.py @@ -9,8 +9,17 @@ class Wrap(HigherOrderOperator): super().__init__("wrap", _deprecated_global_ns=True) def __call__(self, func, *args): - result = func(*args) - return result + # Dynamo already traces the body of HigherOrderOp beforehand when it + # so no need to trace into it. + import torch._dynamo # noqa: F401 + from torch._dynamo.eval_frame import disable + + @disable + def wrapper(): + result = func(*args) + return result + + return wrapper() wrap = Wrap()