From bd8d7b1b74b0ecc9da9f077a4b8c9b8801c3973f Mon Sep 17 00:00:00 2001
From: Yanbo Liang
Date: Sun, 2 Feb 2025 22:26:27 -0800
Subject: [PATCH] [Dynamo][Trace PyDispatcher] Remove disable from
 HigherOrderOperator.__call__ (#146270)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146270
Approved by: https://github.com/zou3519
---
 test/dynamo/test_higher_order_ops.py | 6 ++++--
 torch/_ops.py                        | 5 -----
 2 files changed, 4 insertions(+), 7 deletions(-)

diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py
index 40cd37e344a..3704d9e5c53 100644
--- a/test/dynamo/test_higher_order_ops.py
+++ b/test/dynamo/test_higher_order_ops.py
@@ -2283,7 +2283,8 @@ def forward(self):
 
         res = mod_for_compile(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))
         # There is graph break right when we enter body of map
-        self.assertEqual(len(backend.graphs), 0)
+        # Since we are tracing through the Python dispatch logic, it ends up with 8 graphs.
+        self.assertEqual(len(backend.graphs), 8)
         self.assertEqual(
             res, mod_for_eager(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))
         )
@@ -2319,7 +2320,8 @@ def forward(self):
         eager = mod_for_eager(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))
         eager = mod_for_eager(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))
 
-        self.assertEqual(len(backend.graphs), 0)
+        # Since we are tracing through the Python dispatch logic, it ends up with 9 graphs.
+        self.assertEqual(len(backend.graphs), 9)
         self.assertEqual(res, eager)
 
     def test_wrap_subgraph_name_is_valid(self):
diff --git a/torch/_ops.py b/torch/_ops.py
index a33da4166f8..4df2de10539 100644
--- a/torch/_ops.py
+++ b/torch/_ops.py
@@ -456,11 +456,6 @@ class HigherOrderOperator(OperatorBase, abc.ABC):
 
     @abc.abstractmethod
     def __call__(self, /, *args, **kwargs):
-        # Dynamo already traces the body of HigherOrderOp beforehand when it
-        # so no need to trace into it.
-        from torch._dynamo import disable
-
-        @disable
         def wrapper():
             flat_args = _to_flat_tuple(args, kwargs)
             if torch.overrides.has_torch_function(flat_args):
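
---

As a quick illustration of the behavior change, below is a minimal sketch (not
part of the patch) of how the affected tests observe it. EagerAndRecordGraphs
and control_flow.map are the real helpers used in test_higher_order_ops.py;
the names body and fn are hypothetical, and the exact graph count depends on
the PyTorch build.

    import torch
    import torch._dynamo
    from torch._dynamo.testing import EagerAndRecordGraphs
    from functorch.experimental import control_flow

    backend = EagerAndRecordGraphs()

    def body(x):
        # Force a graph break inside the mapped body. Before this patch,
        # HigherOrderOperator.__call__ was wrapped in @disable, so Dynamo
        # skipped it entirely and the backend recorded 0 graphs.
        torch._dynamo.graph_break()
        return x.sin()

    @torch.compile(backend=backend)
    def fn(xs):
        return control_flow.map(body, xs)

    fn(torch.randn(3, 4))
    # With @disable removed, Dynamo traces into the Python dispatch logic of
    # the HigherOrderOperator, so several partial graphs are recorded.
    print(len(backend.graphs))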