mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Stop Dynamo from peeking into wrap's body (#104076)
When Dynamo sees `wrap(f, x)`, and it decides that `f` is unsafe, Dynamo should fall back to eager mode and stop introspection all the way throughout the call of `f`. The motivation is: - it's easier to test `wrap` this way (it is clearer how many graph breaks should occur) - Other HigherOrderOperator do this because their execution of the body involves code that is not necessarily Dynamo-able. e.g. functorch transforms. Since `wrap` is a test for the HigherOrderOp mechanism, it should reflect what other HigherOrderOps do. Pull Request resolved: https://github.com/pytorch/pytorch/pull/104076 Approved by: https://github.com/ydwu4
This commit is contained in:
parent
5364366f8c
commit
618cc82e77
2 changed files with 16 additions and 13 deletions
|
|
@ -607,10 +607,7 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
|
|||
|
||||
result = f(x)
|
||||
self.assertEqual(result, inner(x))
|
||||
# It's unclear if this is correct: dynamo graph breaks on wrap but
|
||||
# then interposes on wrap.__call__, which invokes fn(*args),
|
||||
# leading to two graphs being compiled
|
||||
self.assertEqual(cnt.frame_count, 2)
|
||||
self.assertEqual(cnt.frame_count, 0)
|
||||
|
||||
def test_fallback_on_graph_break_complicated(self):
|
||||
cnt = CompileCounter()
|
||||
|
|
@ -631,10 +628,7 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
|
|||
|
||||
result = f(x)
|
||||
self.assertEqual(result, inner(x))
|
||||
# It's unclear if this is correct: dynamo graph breaks on wrap but
|
||||
# then interposes on wrap.__call__, which invokes fn(*args),
|
||||
# leading to four graphs being compiled: clone, sin, sin, clone
|
||||
self.assertEqual(cnt.frame_count, 4)
|
||||
self.assertEqual(cnt.frame_count, 2)
|
||||
|
||||
def test_modules(self):
|
||||
counters.clear()
|
||||
|
|
@ -688,7 +682,7 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
|
|||
x = torch.randn(3)
|
||||
result = f(x)
|
||||
self.assertEqual(result, [1, torch.sin(x), 2.0])
|
||||
self.assertEqual(cnt.frame_count, 1)
|
||||
self.assertEqual(cnt.frame_count, 0)
|
||||
self.assertEqual(
|
||||
dict(counters["graph_break"]),
|
||||
{"HigherOrderOperator body's output must consist of tensors only": 1},
|
||||
|
|
@ -708,7 +702,7 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
|
|||
result = f(x)
|
||||
|
||||
self.assertEqual(result, ((x.sin(), x.cos()),))
|
||||
self.assertEqual(cnt.frame_count, 1)
|
||||
self.assertEqual(cnt.frame_count, 0)
|
||||
self.assertEqual(
|
||||
dict(counters["graph_break"]),
|
||||
{"HigherOrderOperator body's output must consist of tensors only": 1},
|
||||
|
|
@ -727,7 +721,7 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
|
|||
x = torch.randn(3)
|
||||
result = f(x)
|
||||
self.assertEqual(result, [{"a": -x}])
|
||||
self.assertEqual(cnt.frame_count, 1)
|
||||
self.assertEqual(cnt.frame_count, 0)
|
||||
self.assertEqual(
|
||||
dict(counters["graph_break"]),
|
||||
{"HigherOrderOperator body's output must consist of tensors only": 1},
|
||||
|
|
|
|||
|
|
@ -9,8 +9,17 @@ class Wrap(HigherOrderOperator):
|
|||
super().__init__("wrap", _deprecated_global_ns=True)
|
||||
|
||||
def __call__(self, func, *args):
|
||||
result = func(*args)
|
||||
return result
|
||||
# Dynamo already traces the body of HigherOrderOp beforehand when it
|
||||
# so no need to trace into it.
|
||||
import torch._dynamo # noqa: F401
|
||||
from torch._dynamo.eval_frame import disable
|
||||
|
||||
@disable
|
||||
def wrapper():
|
||||
result = func(*args)
|
||||
return result
|
||||
|
||||
return wrapper()
|
||||
|
||||
wrap = Wrap()
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue