Stop Dynamo from peeking into wrap's body (#104076)

When Dynamo sees `wrap(f, x)`, and it decides that `f` is unsafe, Dynamo
should fall back to eager mode and stop introspection all the way
throughout the call of `f`. The motivation is:
- it's easier to test `wrap` this way (it is clearer how many graph
breaks should occur)
- Other HigherOrderOperator do this because their execution of the
body involves code that is not necessarily Dynamo-able. e.g. functorch
transforms. Since `wrap` is a test for the HigherOrderOp mechanism, it
should reflect what other HigherOrderOps do.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104076
Approved by: https://github.com/ydwu4
This commit is contained in:
Richard Zou 2023-06-23 12:56:17 -07:00 committed by PyTorch MergeBot
parent 5364366f8c
commit 618cc82e77
2 changed files with 16 additions and 13 deletions

View file

@ -607,10 +607,7 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
result = f(x)
self.assertEqual(result, inner(x))
# It's unclear if this is correct: dynamo graph breaks on wrap but
# then interposes on wrap.__call__, which invokes fn(*args),
# leading to two graphs being compiled
self.assertEqual(cnt.frame_count, 2)
self.assertEqual(cnt.frame_count, 0)
def test_fallback_on_graph_break_complicated(self):
cnt = CompileCounter()
@ -631,10 +628,7 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
result = f(x)
self.assertEqual(result, inner(x))
# It's unclear if this is correct: dynamo graph breaks on wrap but
# then interposes on wrap.__call__, which invokes fn(*args),
# leading to four graphs being compiled: clone, sin, sin, clone
self.assertEqual(cnt.frame_count, 4)
self.assertEqual(cnt.frame_count, 2)
def test_modules(self):
counters.clear()
@ -688,7 +682,7 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
x = torch.randn(3)
result = f(x)
self.assertEqual(result, [1, torch.sin(x), 2.0])
self.assertEqual(cnt.frame_count, 1)
self.assertEqual(cnt.frame_count, 0)
self.assertEqual(
dict(counters["graph_break"]),
{"HigherOrderOperator body's output must consist of tensors only": 1},
@ -708,7 +702,7 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
result = f(x)
self.assertEqual(result, ((x.sin(), x.cos()),))
self.assertEqual(cnt.frame_count, 1)
self.assertEqual(cnt.frame_count, 0)
self.assertEqual(
dict(counters["graph_break"]),
{"HigherOrderOperator body's output must consist of tensors only": 1},
@ -727,7 +721,7 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
x = torch.randn(3)
result = f(x)
self.assertEqual(result, [{"a": -x}])
self.assertEqual(cnt.frame_count, 1)
self.assertEqual(cnt.frame_count, 0)
self.assertEqual(
dict(counters["graph_break"]),
{"HigherOrderOperator body's output must consist of tensors only": 1},

View file

@ -9,8 +9,17 @@ class Wrap(HigherOrderOperator):
super().__init__("wrap", _deprecated_global_ns=True)
def __call__(self, func, *args):
result = func(*args)
return result
# Dynamo already traces the body of HigherOrderOp beforehand when it
# so no need to trace into it.
import torch._dynamo # noqa: F401
from torch._dynamo.eval_frame import disable
@disable
def wrapper():
result = func(*args)
return result
return wrapper()
wrap = Wrap()