[HigherOrderOp] move some common utils in cond to utils.py (#116721)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116721
Approved by: https://github.com/zou3519
ydwu4 2024-01-05 17:10:33 -08:00 committed by PyTorch MergeBot
parent 77cfacab55
commit 113f0749f5
3 changed files with 128 additions and 120 deletions
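In short: the helpers that previously lived in torch._higher_order_ops.cond (_set_compilation_env, _maybe_run_with_interpreter, _has_potential_branch_input_mutation, _has_potential_branch_input_alias, and UnsupportedAliasMutationException) now live in torch._higher_order_ops.utils, and callers import them from there. A minimal sketch of the import after this change, taken from the diff below (no new API is introduced by the move):

from torch._higher_order_ops.utils import (
    _has_potential_branch_input_alias,
    _has_potential_branch_input_mutation,
    _maybe_run_with_interpreter,
    _set_compilation_env,
    autograd_not_implemented,
    UnsupportedAliasMutationException,
)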

View file

@@ -1,9 +1,5 @@
from contextlib import contextmanager
from dataclasses import dataclass
import torch
import torch._subclasses.functional_tensor
import torch.fx.traceback as fx_traceback
import torch.utils._pytree as pytree
@@ -16,7 +12,15 @@ from torch._C._functorch import (
)
from torch._functorch.utils import exposed_in
from torch._higher_order_ops.utils import autograd_not_implemented
from torch._higher_order_ops.utils import (
    _has_potential_branch_input_alias,
    _has_potential_branch_input_mutation,
    _maybe_run_with_interpreter,
    _set_compilation_env,
    autograd_not_implemented,
    UnsupportedAliasMutationException,
)
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
@@ -26,27 +30,9 @@ from torch.fx.experimental.proxy_tensor import (
    track_tensor_tree,
)
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils._python_dispatch import _get_current_dispatch_mode
@contextmanager
def _set_compilation_env():
    _old_is_tracing = torch.fx._symbolic_trace._is_fx_tracing_flag
    try:
        # We need to turn off the is_fx_tracing_flag. Remove this flag check from
        # dynamo once we are confident fx tracing works with dynamo.
        torch.fx._symbolic_trace._is_fx_tracing_flag = False
        yield
    finally:
        torch.fx._symbolic_trace._is_fx_tracing_flag = _old_is_tracing


@dataclass
class UnsupportedAliasMutationException(RuntimeError):
    reason: str

@exposed_in("torch")
def cond(pred, true_fn, false_fn, operands):
    r"""
@@ -160,18 +146,6 @@ In order to do this, we need implementations for each of the dispatch keys.
cond_op = HigherOrderOperator("cond")
def _maybe_run_with_interpreter(fn):
    maybe_interpreted_fn = fn
    if isinstance(fn, torch.fx.GraphModule) and fx_traceback.has_preserved_node_meta():
        # Running the graph with an interpreter is needed to propagate the stack_trace
        def graph_with_interpreter(*args):
            with fx_traceback.preserve_node_meta():
                return torch.fx.Interpreter(fn).run(*args)

        maybe_interpreted_fn = graph_with_interpreter
    return maybe_interpreted_fn

def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
    assert isinstance(
        operands, (list, tuple)
@@ -304,90 +278,6 @@ def cond_fake_tensor_mode(mode, pred, true_fn, false_fn, operands):
    return true_outs
def _has_potential_branch_input_mutation(branch, inputs):
    """
    Dispatch-trace the branch with the given inputs and check whether the
    resulting graph contains a mutating op on any input. This is a bit
    restrictive, as the branch must be traceable.
    """
    try:
        gm = make_fx(branch)(*inputs)
    except UnsupportedAliasMutationException:
        # this can happen when a nested cond_op is functionalized
        return True
    except Exception as e:
        raise e

    def _detect_input_mutation(gm):
        input_nodes = set()
        for node in gm.graph.nodes:
            if node.op == "placeholder":
                input_nodes.add(node)
            if node.op == "call_function":
                target = node.target
                if (
                    isinstance(target, torch._ops.OpOverload)
                    and target._schema.is_mutable
                ):
                    for arg in node.args:
                        if arg in input_nodes:
                            return True

        for _, module in gm.named_children():
            if isinstance(module, torch.fx.GraphModule):
                if _detect_input_mutation(module):
                    return True

        return False

    return _detect_input_mutation(gm)

def _has_potential_branch_input_alias(branch, inputs):
    """
    Dispatch-trace the branch with the given inputs and check whether the
    resulting graph has an output that aliases a branch input. This is a bit
    restrictive, as the branch must be traceable.
    """
    try:
        gm = make_fx(branch)(*inputs)
    except UnsupportedAliasMutationException:
        # this can happen when a nested cond_op is functionalized
        return True
    except Exception as e:
        raise e

    def _detect_input_alias(gm):
        input_storages = set()
        for node in gm.graph.nodes:
            # We need to check for the existence of "val" because we reuse this
            # logic for the map operator, where num_mapped_args is a scalar
            # and doesn't have a "val" meta.
            if node.op == "placeholder" and "val" in node.meta:
                input_storages.add(StorageWeakRef(node.meta["val"]._typed_storage()))
            if node.op == "output":

                def check_alias(out):
                    if out is not None and "val" in out.meta:
                        out_storage = StorageWeakRef(out.meta["val"]._typed_storage())
                        return out_storage in input_storages
                    return False

                if any(pytree.tree_leaves(pytree.tree_map(check_alias, node.args))):
                    return True

        for _, module in gm.named_children():
            if isinstance(module, torch.fx.GraphModule) and _detect_input_alias(module):
                return True

        return False

    return _detect_input_alias(gm)

@cond_op.py_functionalize_impl
def cond_func(ctx, pred, true_fn, false_fn, inputs):
    unwrapped_inputs = ctx.unwrap_tensors(inputs)
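For context, the two _has_potential_branch_input_* helpers removed above (and re-added in utils.py further down) trace a branch with make_fx and report whether the branch mutates or aliases its inputs, which cond() disallows. A hypothetical illustration of the branch shapes each check is meant to flag (these functions are invented for the example and are not part of this commit):

import torch

def mutating_branch(x):
    x.add_(1)           # in-place op on the branch input -> input-mutation check fires
    return x.cos()

def aliasing_branch(x):
    return x.view(-1)   # output shares storage with the input -> aliasing check fires

def clean_branch(x):
    return x.cos() + 1  # purely out-of-place compute -> passes both checks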

View file

@@ -6,7 +6,7 @@ from torch._C import DispatchKey
from torch._dispatch.python import suspend_functionalization
from torch._functorch.aot_autograd import AOTConfig, create_joint, from_fun
from torch._higher_order_ops.cond import (
from torch._higher_order_ops.utils import (
    _has_potential_branch_input_alias,
    _has_potential_branch_input_mutation,
    UnsupportedAliasMutationException,

View file

@@ -1,8 +1,18 @@
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Callable
import torch
import torch.fx.traceback as fx_traceback
import torch.utils._pytree as pytree
from torch._ops import HigherOrderOperator
from torch.fx.experimental.proxy_tensor import make_fx
from torch.multiprocessing.reductions import StorageWeakRef
@dataclass
class UnsupportedAliasMutationException(RuntimeError):
    reason: str

def autograd_not_implemented_inner(
@@ -52,3 +62,111 @@ def autograd_not_implemented(op: HigherOrderOperator, deferred_error: bool) -> C
        return autograd_not_implemented_inner(op, deferred_error, *args, **kwargs)

    return inner

def _maybe_run_with_interpreter(fn):
    maybe_interpreted_fn = fn
    if isinstance(fn, torch.fx.GraphModule) and fx_traceback.has_preserved_node_meta():
        # Running the graph with an interpreter is needed to propagate the stack_trace
        def graph_with_interpreter(*args):
            with fx_traceback.preserve_node_meta():
                return torch.fx.Interpreter(fn).run(*args)

        maybe_interpreted_fn = graph_with_interpreter
    return maybe_interpreted_fn

@contextmanager
def _set_compilation_env():
    _old_is_tracing = torch.fx._symbolic_trace._is_fx_tracing_flag
    try:
        # We need to turn off the is_fx_tracing_flag. Remove this flag check from
        # dynamo once we are confident fx tracing works with dynamo.
        torch.fx._symbolic_trace._is_fx_tracing_flag = False
        yield
    finally:
        torch.fx._symbolic_trace._is_fx_tracing_flag = _old_is_tracing

def _has_potential_branch_input_mutation(branch, inputs):
    """
    Dispatch-trace the branch with the given inputs and check whether the
    resulting graph contains a mutating op on any input. This is a bit
    restrictive, as the branch must be traceable.
    """
    try:
        gm = make_fx(branch)(*inputs)
    except UnsupportedAliasMutationException:
        # this can happen when a nested cond_op is functionalized
        return True
    except Exception as e:
        raise e

    def _detect_input_mutation(gm):
        input_nodes = set()
        for node in gm.graph.nodes:
            if node.op == "placeholder":
                input_nodes.add(node)
            if node.op == "call_function":
                target = node.target
                if (
                    isinstance(target, torch._ops.OpOverload)
                    and target._schema.is_mutable
                ):
                    for arg in node.args:
                        if arg in input_nodes:
                            return True

        for _, module in gm.named_children():
            if isinstance(module, torch.fx.GraphModule):
                if _detect_input_mutation(module):
                    return True

        return False

    return _detect_input_mutation(gm)

def _has_potential_branch_input_alias(branch, inputs):
    """
    Dispatch-trace the branch with the given inputs and check whether the
    resulting graph has an output that aliases a branch input. This is a bit
    restrictive, as the branch must be traceable.
    """
    try:
        gm = make_fx(branch)(*inputs)
    except UnsupportedAliasMutationException:
        # this can happen when a nested cond_op is functionalized
        return True
    except Exception as e:
        raise e

    def _detect_input_alias(gm):
        input_storages = set()
        for node in gm.graph.nodes:
            # We need to check for the existence of "val" because we reuse this
            # logic for the map operator, where num_mapped_args is a scalar
            # and doesn't have a "val" meta.
            if node.op == "placeholder" and "val" in node.meta:
                input_storages.add(StorageWeakRef(node.meta["val"]._typed_storage()))
            if node.op == "output":

                def check_alias(out):
                    if out is not None and "val" in out.meta:
                        out_storage = StorageWeakRef(out.meta["val"]._typed_storage())
                        return out_storage in input_storages
                    return False

                if any(pytree.tree_leaves(pytree.tree_map(check_alias, node.args))):
                    return True

        for _, module in gm.named_children():
            if isinstance(module, torch.fx.GraphModule) and _detect_input_alias(module):
                return True

        return False

    return _detect_input_alias(gm)
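A small usage sketch, assuming the signatures shown above are unchanged by the move (the branch function and tensor shapes are invented for illustration); the helper returns True when the traced branch mutates its inputs:

import torch
from torch._higher_order_ops.utils import _has_potential_branch_input_mutation

def clean_branch(x):
    return x.cos() + 1  # out-of-place compute only

def mutating_branch(x):
    x.add_(1)           # in-place op on the branch input
    return x

inputs = (torch.randn(2, 3),)
# Expected, under the make_fx tracing described in the docstrings above:
print(_has_potential_branch_input_mutation(clean_branch, inputs))     # False
print(_has_potential_branch_input_mutation(mutating_branch, inputs))  # True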