[HigherOrderOp] makes control flow operators respect global decomp table (#120412)

A follow up of @zou3519 's comment on https://github.com/pytorch/pytorch/pull/120366. We create a helper method for this purpose.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120412
Approved by: https://github.com/zou3519
This commit is contained in:
ydwu4 2024-02-22 13:10:05 -08:00 committed by PyTorch MergeBot
parent 156954d6a2
commit 8f4ffd3d8a
4 changed files with 22 additions and 26 deletions

View file

@ -15,9 +15,9 @@ from torch._functorch.utils import exposed_in
from torch._higher_order_ops.utils import (
_has_potential_branch_input_alias,
_has_potential_branch_input_mutation,
_maybe_run_with_interpreter,
_set_compilation_env,
autograd_not_implemented,
reenter_make_fx,
UnsupportedAliasMutationException,
)
@ -25,7 +25,6 @@ from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
disable_proxy_modes_tracing,
make_fx,
ProxyTorchDispatchMode,
track_tensor_tree,
)
@ -157,19 +156,8 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
pre_dispatch = getattr(proxy_mode, "pre_dispatch", False)
with disable_proxy_modes_tracing():
# We'll use the current decomposition table to make sure operatos in subgraphs are
# decomposed properly.
decomp_table = torch.fx.experimental.proxy_tensor.CURRENT_DECOMPOSITION_TABLE
true_graph = make_fx(
_maybe_run_with_interpreter(true_fn),
decomposition_table=decomp_table,
pre_dispatch=pre_dispatch,
)(*operands)
false_graph = make_fx(
_maybe_run_with_interpreter(false_fn),
decomposition_table=decomp_table,
pre_dispatch=pre_dispatch,
)(*operands)
true_graph = reenter_make_fx(true_fn, pre_dispatch)(*operands)
false_graph = reenter_make_fx(false_fn, pre_dispatch)(*operands)
true_outs = []
false_outs = []

View file

@ -7,6 +7,7 @@ from torch._functorch.aot_autograd import AOTConfig, create_joint, from_fun
from torch._higher_order_ops.utils import (
_has_potential_branch_input_alias,
_has_potential_branch_input_mutation,
reenter_make_fx,
UnsupportedAliasMutationException,
)
from torch._ops import HigherOrderOperator
@ -228,8 +229,9 @@ def trace_map(proxy_mode, func_overload, f, xs, pos_args):
example_input = _unstack_pytree(xs)[0]
body_graph = f
if not isinstance(body_graph, torch.fx.GraphModule):
body_graph = make_fx(body_graph)(*example_input, *pos_args)
pre_dispatch = getattr(proxy_mode, "pre_dispatch", False)
body_graph = reenter_make_fx(body_graph, pre_dispatch)(*example_input, *pos_args)
next_name = None
i = 0

View file

@ -76,6 +76,18 @@ def _maybe_run_with_interpreter(fn):
return maybe_interpreted_fn
# We'll use the current decomposition table to make sure operators in subgraphs are
# decomposed properly.
# We also need to maybe run with interpreter for propagating stack_trace
def reenter_make_fx(fn, pre_dispatch=False):
decomp_table = torch.fx.experimental.proxy_tensor.CURRENT_DECOMPOSITION_TABLE
return make_fx(
_maybe_run_with_interpreter(fn),
decomposition_table=decomp_table,
pre_dispatch=pre_dispatch,
)
@contextmanager
def _set_compilation_env():
_old_is_tracing = torch.fx._symbolic_trace._is_fx_tracing_flag

View file

@ -6,16 +6,15 @@ from torch._C import DispatchKey
from torch._higher_order_ops.utils import (
_has_potential_branch_input_alias,
_has_potential_branch_input_mutation,
_maybe_run_with_interpreter,
_set_compilation_env,
autograd_not_implemented,
reenter_make_fx,
UnsupportedAliasMutationException,
)
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
disable_proxy_modes_tracing,
make_fx,
ProxyTorchDispatchMode,
track_tensor_tree,
)
@ -159,14 +158,9 @@ while_loop_op.py_impl(DispatchKey.Autograd)(
def while_loop_tracing(mode, cond_fn, body_fn, operands):
def _trace_while_loop(proxy_mode, while_loop_op, cond_fn, body_fn, operands):
pre_dispatch = getattr(proxy_mode, "pre_dispatch", False)
with disable_proxy_modes_tracing():
cond_graph = make_fx(
_maybe_run_with_interpreter(cond_fn), pre_dispatch=pre_dispatch
)(*operands)
body_graph = make_fx(
_maybe_run_with_interpreter(body_fn), pre_dispatch=pre_dispatch
)(*operands)
cond_graph = reenter_make_fx(cond_fn, pre_dispatch)(*operands)
body_graph = reenter_make_fx(body_fn, pre_dispatch)(*operands)
next_name = None
i = 0