From 8f4ffd3d8a97b8dcae784817e3dcadd41c88133b Mon Sep 17 00:00:00 2001 From: ydwu4 Date: Thu, 22 Feb 2024 13:10:05 -0800 Subject: [PATCH] [HigherOrderOp] makes control flow operators respect global decomp table (#120412) A follow up of @zou3519 's comment on https://github.com/pytorch/pytorch/pull/120366. We create a helper method for this purpose. Pull Request resolved: https://github.com/pytorch/pytorch/pull/120412 Approved by: https://github.com/zou3519 --- torch/_higher_order_ops/cond.py | 18 +++--------------- torch/_higher_order_ops/map.py | 6 ++++-- torch/_higher_order_ops/utils.py | 12 ++++++++++++ torch/_higher_order_ops/while_loop.py | 12 +++--------- 4 files changed, 22 insertions(+), 26 deletions(-) diff --git a/torch/_higher_order_ops/cond.py b/torch/_higher_order_ops/cond.py index 0dd5eb82369..cfc3d33bcb2 100644 --- a/torch/_higher_order_ops/cond.py +++ b/torch/_higher_order_ops/cond.py @@ -15,9 +15,9 @@ from torch._functorch.utils import exposed_in from torch._higher_order_ops.utils import ( _has_potential_branch_input_alias, _has_potential_branch_input_mutation, - _maybe_run_with_interpreter, _set_compilation_env, autograd_not_implemented, + reenter_make_fx, UnsupportedAliasMutationException, ) @@ -25,7 +25,6 @@ from torch._ops import HigherOrderOperator from torch._subclasses.fake_tensor import FakeTensorMode from torch.fx.experimental.proxy_tensor import ( disable_proxy_modes_tracing, - make_fx, ProxyTorchDispatchMode, track_tensor_tree, ) @@ -157,19 +156,8 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands): pre_dispatch = getattr(proxy_mode, "pre_dispatch", False) with disable_proxy_modes_tracing(): - # We'll use the current decomposition table to make sure operatos in subgraphs are - # decomposed properly. - decomp_table = torch.fx.experimental.proxy_tensor.CURRENT_DECOMPOSITION_TABLE - true_graph = make_fx( - _maybe_run_with_interpreter(true_fn), - decomposition_table=decomp_table, - pre_dispatch=pre_dispatch, - )(*operands) - false_graph = make_fx( - _maybe_run_with_interpreter(false_fn), - decomposition_table=decomp_table, - pre_dispatch=pre_dispatch, - )(*operands) + true_graph = reenter_make_fx(true_fn, pre_dispatch)(*operands) + false_graph = reenter_make_fx(false_fn, pre_dispatch)(*operands) true_outs = [] false_outs = [] diff --git a/torch/_higher_order_ops/map.py b/torch/_higher_order_ops/map.py index 2f18925dfe3..b2d2025f0a4 100644 --- a/torch/_higher_order_ops/map.py +++ b/torch/_higher_order_ops/map.py @@ -7,6 +7,7 @@ from torch._functorch.aot_autograd import AOTConfig, create_joint, from_fun from torch._higher_order_ops.utils import ( _has_potential_branch_input_alias, _has_potential_branch_input_mutation, + reenter_make_fx, UnsupportedAliasMutationException, ) from torch._ops import HigherOrderOperator @@ -228,8 +229,9 @@ def trace_map(proxy_mode, func_overload, f, xs, pos_args): example_input = _unstack_pytree(xs)[0] body_graph = f - if not isinstance(body_graph, torch.fx.GraphModule): - body_graph = make_fx(body_graph)(*example_input, *pos_args) + + pre_dispatch = getattr(proxy_mode, "pre_dispatch", False) + body_graph = reenter_make_fx(body_graph, pre_dispatch)(*example_input, *pos_args) next_name = None i = 0 diff --git a/torch/_higher_order_ops/utils.py b/torch/_higher_order_ops/utils.py index b100debc618..d3b39b26415 100644 --- a/torch/_higher_order_ops/utils.py +++ b/torch/_higher_order_ops/utils.py @@ -76,6 +76,18 @@ def _maybe_run_with_interpreter(fn): return maybe_interpreted_fn +# We'll use the current decomposition table to make sure operators in subgraphs are +# decomposed properly. +# We also need to maybe run with interpreter for propagating stack_trace +def reenter_make_fx(fn, pre_dispatch=False): + decomp_table = torch.fx.experimental.proxy_tensor.CURRENT_DECOMPOSITION_TABLE + return make_fx( + _maybe_run_with_interpreter(fn), + decomposition_table=decomp_table, + pre_dispatch=pre_dispatch, + ) + + @contextmanager def _set_compilation_env(): _old_is_tracing = torch.fx._symbolic_trace._is_fx_tracing_flag diff --git a/torch/_higher_order_ops/while_loop.py b/torch/_higher_order_ops/while_loop.py index 5b6826e0a6b..f2224e41aba 100644 --- a/torch/_higher_order_ops/while_loop.py +++ b/torch/_higher_order_ops/while_loop.py @@ -6,16 +6,15 @@ from torch._C import DispatchKey from torch._higher_order_ops.utils import ( _has_potential_branch_input_alias, _has_potential_branch_input_mutation, - _maybe_run_with_interpreter, _set_compilation_env, autograd_not_implemented, + reenter_make_fx, UnsupportedAliasMutationException, ) from torch._ops import HigherOrderOperator from torch._subclasses.fake_tensor import FakeTensorMode from torch.fx.experimental.proxy_tensor import ( disable_proxy_modes_tracing, - make_fx, ProxyTorchDispatchMode, track_tensor_tree, ) @@ -159,14 +158,9 @@ while_loop_op.py_impl(DispatchKey.Autograd)( def while_loop_tracing(mode, cond_fn, body_fn, operands): def _trace_while_loop(proxy_mode, while_loop_op, cond_fn, body_fn, operands): pre_dispatch = getattr(proxy_mode, "pre_dispatch", False) - with disable_proxy_modes_tracing(): - cond_graph = make_fx( - _maybe_run_with_interpreter(cond_fn), pre_dispatch=pre_dispatch - )(*operands) - body_graph = make_fx( - _maybe_run_with_interpreter(body_fn), pre_dispatch=pre_dispatch - )(*operands) + cond_graph = reenter_make_fx(cond_fn, pre_dispatch)(*operands) + body_graph = reenter_make_fx(body_fn, pre_dispatch)(*operands) next_name = None i = 0