mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[HigherOrderOp] makes control flow operators respect global decomp table (#120412)
A follow up of @zou3519 's comment on https://github.com/pytorch/pytorch/pull/120366. We create a helper method for this purpose. Pull Request resolved: https://github.com/pytorch/pytorch/pull/120412 Approved by: https://github.com/zou3519
This commit is contained in:
parent
156954d6a2
commit
8f4ffd3d8a
4 changed files with 22 additions and 26 deletions
|
|
@ -15,9 +15,9 @@ from torch._functorch.utils import exposed_in
|
|||
from torch._higher_order_ops.utils import (
|
||||
_has_potential_branch_input_alias,
|
||||
_has_potential_branch_input_mutation,
|
||||
_maybe_run_with_interpreter,
|
||||
_set_compilation_env,
|
||||
autograd_not_implemented,
|
||||
reenter_make_fx,
|
||||
UnsupportedAliasMutationException,
|
||||
)
|
||||
|
||||
|
|
@ -25,7 +25,6 @@ from torch._ops import HigherOrderOperator
|
|||
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||
from torch.fx.experimental.proxy_tensor import (
|
||||
disable_proxy_modes_tracing,
|
||||
make_fx,
|
||||
ProxyTorchDispatchMode,
|
||||
track_tensor_tree,
|
||||
)
|
||||
|
|
@ -157,19 +156,8 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
|
|||
pre_dispatch = getattr(proxy_mode, "pre_dispatch", False)
|
||||
|
||||
with disable_proxy_modes_tracing():
|
||||
# We'll use the current decomposition table to make sure operatos in subgraphs are
|
||||
# decomposed properly.
|
||||
decomp_table = torch.fx.experimental.proxy_tensor.CURRENT_DECOMPOSITION_TABLE
|
||||
true_graph = make_fx(
|
||||
_maybe_run_with_interpreter(true_fn),
|
||||
decomposition_table=decomp_table,
|
||||
pre_dispatch=pre_dispatch,
|
||||
)(*operands)
|
||||
false_graph = make_fx(
|
||||
_maybe_run_with_interpreter(false_fn),
|
||||
decomposition_table=decomp_table,
|
||||
pre_dispatch=pre_dispatch,
|
||||
)(*operands)
|
||||
true_graph = reenter_make_fx(true_fn, pre_dispatch)(*operands)
|
||||
false_graph = reenter_make_fx(false_fn, pre_dispatch)(*operands)
|
||||
|
||||
true_outs = []
|
||||
false_outs = []
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from torch._functorch.aot_autograd import AOTConfig, create_joint, from_fun
|
|||
from torch._higher_order_ops.utils import (
|
||||
_has_potential_branch_input_alias,
|
||||
_has_potential_branch_input_mutation,
|
||||
reenter_make_fx,
|
||||
UnsupportedAliasMutationException,
|
||||
)
|
||||
from torch._ops import HigherOrderOperator
|
||||
|
|
@ -228,8 +229,9 @@ def trace_map(proxy_mode, func_overload, f, xs, pos_args):
|
|||
|
||||
example_input = _unstack_pytree(xs)[0]
|
||||
body_graph = f
|
||||
if not isinstance(body_graph, torch.fx.GraphModule):
|
||||
body_graph = make_fx(body_graph)(*example_input, *pos_args)
|
||||
|
||||
pre_dispatch = getattr(proxy_mode, "pre_dispatch", False)
|
||||
body_graph = reenter_make_fx(body_graph, pre_dispatch)(*example_input, *pos_args)
|
||||
|
||||
next_name = None
|
||||
i = 0
|
||||
|
|
|
|||
|
|
@ -76,6 +76,18 @@ def _maybe_run_with_interpreter(fn):
|
|||
return maybe_interpreted_fn
|
||||
|
||||
|
||||
# We'll use the current decomposition table to make sure operators in subgraphs are
|
||||
# decomposed properly.
|
||||
# We also need to maybe run with interpreter for propagating stack_trace
|
||||
def reenter_make_fx(fn, pre_dispatch=False):
|
||||
decomp_table = torch.fx.experimental.proxy_tensor.CURRENT_DECOMPOSITION_TABLE
|
||||
return make_fx(
|
||||
_maybe_run_with_interpreter(fn),
|
||||
decomposition_table=decomp_table,
|
||||
pre_dispatch=pre_dispatch,
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _set_compilation_env():
|
||||
_old_is_tracing = torch.fx._symbolic_trace._is_fx_tracing_flag
|
||||
|
|
|
|||
|
|
@ -6,16 +6,15 @@ from torch._C import DispatchKey
|
|||
from torch._higher_order_ops.utils import (
|
||||
_has_potential_branch_input_alias,
|
||||
_has_potential_branch_input_mutation,
|
||||
_maybe_run_with_interpreter,
|
||||
_set_compilation_env,
|
||||
autograd_not_implemented,
|
||||
reenter_make_fx,
|
||||
UnsupportedAliasMutationException,
|
||||
)
|
||||
from torch._ops import HigherOrderOperator
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||
from torch.fx.experimental.proxy_tensor import (
|
||||
disable_proxy_modes_tracing,
|
||||
make_fx,
|
||||
ProxyTorchDispatchMode,
|
||||
track_tensor_tree,
|
||||
)
|
||||
|
|
@ -159,14 +158,9 @@ while_loop_op.py_impl(DispatchKey.Autograd)(
|
|||
def while_loop_tracing(mode, cond_fn, body_fn, operands):
|
||||
def _trace_while_loop(proxy_mode, while_loop_op, cond_fn, body_fn, operands):
|
||||
pre_dispatch = getattr(proxy_mode, "pre_dispatch", False)
|
||||
|
||||
with disable_proxy_modes_tracing():
|
||||
cond_graph = make_fx(
|
||||
_maybe_run_with_interpreter(cond_fn), pre_dispatch=pre_dispatch
|
||||
)(*operands)
|
||||
body_graph = make_fx(
|
||||
_maybe_run_with_interpreter(body_fn), pre_dispatch=pre_dispatch
|
||||
)(*operands)
|
||||
cond_graph = reenter_make_fx(cond_fn, pre_dispatch)(*operands)
|
||||
body_graph = reenter_make_fx(body_fn, pre_dispatch)(*operands)
|
||||
|
||||
next_name = None
|
||||
i = 0
|
||||
|
|
|
|||
Loading…
Reference in a new issue