[Dynamo][Trace PyDispatcher] Remove disable from HigherOrderOperator.__call__ (#146270)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146270
Approved by: https://github.com/zou3519
Commit: bd8d7b1b74 (parent: fd73ae2068)
Author: Yanbo Liang, 2025-02-02 22:26:27 -08:00; committed by PyTorch MergeBot
2 changed files with 4 additions and 7 deletions
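
For context: `torch._dynamo.disable` marks a function so that Dynamo never traces into it, forcing a graph break at every call site inside a compiled region. A minimal sketch of that behavior (illustrative code, not from this PR):

```python
import torch

@torch._dynamo.disable
def eager_only(x):
    # Dynamo will not trace this body; calls to it split the graph.
    return x.sin()

@torch.compile(backend="eager")
def f(x):
    y = x.cos()        # captured in graph 1
    z = eager_only(y)  # graph break: runs eagerly
    return z + 1       # captured in graph 2

f(torch.randn(3))
```

With the `disable` removed from `HigherOrderOperator.__call__`, Dynamo traces straight through the Python dispatch logic of higher-order ops such as `map`, so the tests below now expect several small graphs instead of a graph break with zero captured graphs.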

test/dynamo/test_higher_order_ops.py

@@ -2283,7 +2283,8 @@ def forward(self):
         res = mod_for_compile(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))
-        # There is graph break right when we enter body of map
-        self.assertEqual(len(backend.graphs), 0)
+        # Since we are tracing through the Python dispatch logic, it ends up
+        # with 8 graphs.
+        self.assertEqual(len(backend.graphs), 8)
         self.assertEqual(
             res, mod_for_eager(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))
         )
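
The `backend.graphs` assertions count compiled graphs; each Dynamo graph break produces an additional entry. A minimal sketch of that counting, assuming the test's `backend` is a recording backend like `torch._dynamo.testing.EagerAndRecordGraphs` (an assumption; the fixture is defined elsewhere in the test file):

```python
import torch
from torch._dynamo.testing import EagerAndRecordGraphs

backend = EagerAndRecordGraphs()

@torch.compile(backend=backend)
def g(x):
    y = x + 1
    if y.sum() > 0:  # data-dependent branch: Dynamo splits the graph here
        return y.sin()
    return y.cos()

g(torch.ones(3))
print(len(backend.graphs))  # 2: one graph before the break, one after
```

Tracing through the dispatcher turns what used to be a single opaque break into a handful of small graphs, hence the jump from 0 to 8 here.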
@@ -2319,7 +2320,8 @@ def forward(self):
         eager = mod_for_eager(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))
         eager = mod_for_eager(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))
-        self.assertEqual(len(backend.graphs), 0)
+        # Since we are tracing through the Python dispatch logic, it ends up
+        # with 9 graphs.
+        self.assertEqual(len(backend.graphs), 9)
         self.assertEqual(res, eager)

     def test_wrap_subgraph_name_is_valid(self):
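
Both affected tests exercise the `map` higher-order op, which applies a traced body over the leading dimension of a tensor. A hedged sketch of its usage, assuming the `functorch.experimental.control_flow` entry point these Dynamo tests target:

```python
import torch
from functorch.experimental import control_flow

def body(x):
    # Traced as the subgraph of the `map` HigherOrderOperator.
    return x.sin() + 1

xs = torch.randn(3, 4)
ys = control_flow.map(body, xs)  # like torch.stack([body(x) for x in xs])
assert torch.allclose(ys, xs.sin() + 1)
```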

torch/_ops.py

@@ -456,11 +456,6 @@ class HigherOrderOperator(OperatorBase, abc.ABC):
     @abc.abstractmethod
     def __call__(self, /, *args, **kwargs):
-        # Dynamo already traces the body of HigherOrderOp beforehand when it
-        # so no need to trace into it.
-        from torch._dynamo import disable
-
-        @disable
         def wrapper():
             flat_args = _to_flat_tuple(args, kwargs)
             if torch.overrides.has_torch_function(flat_args):
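
The surviving `wrapper` body implements the standard `__torch_function__` re-dispatch protocol: flatten the arguments, probe them for overrides, and defer to the override if one is found. A minimal sketch of that pattern using the public `torch.overrides` API (`my_op` and `LoggingTensor` are illustrative stand-ins, not PyTorch code):

```python
import torch
from torch.overrides import handle_torch_function, has_torch_function

def my_op(x, scale=1.0):
    flat_args = (x,)
    if has_torch_function(flat_args):
        # Defer to the argument's __torch_function__ implementation.
        return handle_torch_function(my_op, flat_args, x, scale=scale)
    return x * scale

class LoggingTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        print(f"dispatching {getattr(func, '__name__', func)}")
        return super().__torch_function__(func, types, args, kwargs or {})

t = torch.randn(3).as_subclass(LoggingTensor)
my_op(t, scale=2.0)  # prints "dispatching my_op" before computing
```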