mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[Dynamo][Trace PyDispatcher] Remove disable from HigherOrderOperator.__call__ (#146270)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146270 Approved by: https://github.com/zou3519
This commit is contained in:
parent
fd73ae2068
commit
bd8d7b1b74
2 changed files with 4 additions and 7 deletions
|
|
@ -2283,7 +2283,8 @@ def forward(self):
|
|||
|
||||
res = mod_for_compile(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))
|
||||
# There is graph break right when we enter body of map
|
||||
self.assertEqual(len(backend.graphs), 0)
|
||||
# Since we are tracing through the Python dispatch logic, it ends up 8 graphs.
|
||||
self.assertEqual(len(backend.graphs), 8)
|
||||
self.assertEqual(
|
||||
res, mod_for_eager(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))
|
||||
)
|
||||
|
|
@ -2319,7 +2320,8 @@ def forward(self):
|
|||
eager = mod_for_eager(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))
|
||||
eager = mod_for_eager(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))
|
||||
|
||||
self.assertEqual(len(backend.graphs), 0)
|
||||
# Since we are tracing through the Python dispatch logic, it ends up 9 graphs.
|
||||
self.assertEqual(len(backend.graphs), 9)
|
||||
self.assertEqual(res, eager)
|
||||
|
||||
def test_wrap_subgraph_name_is_valid(self):
|
||||
|
|
|
|||
|
|
@ -456,11 +456,6 @@ class HigherOrderOperator(OperatorBase, abc.ABC):
|
|||
|
||||
@abc.abstractmethod
|
||||
def __call__(self, /, *args, **kwargs):
|
||||
# Dynamo already traces the body of HigherOrderOp beforehand when it
|
||||
# so no need to trace into it.
|
||||
from torch._dynamo import disable
|
||||
|
||||
@disable
|
||||
def wrapper():
|
||||
flat_args = _to_flat_tuple(args, kwargs)
|
||||
if torch.overrides.has_torch_function(flat_args):
|
||||
|
|
|
|||
Loading…
Reference in a new issue