mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[HigherOrderOp] move some common utils in cond to utils.py (#116721)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116721 Approved by: https://github.com/zou3519
This commit is contained in:
parent
77cfacab55
commit
113f0749f5
3 changed files with 128 additions and 120 deletions
|
|
@ -1,9 +1,5 @@
|
|||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
import torch._subclasses.functional_tensor
|
||||
import torch.fx.traceback as fx_traceback
|
||||
|
||||
import torch.utils._pytree as pytree
|
||||
|
||||
|
|
@ -16,7 +12,15 @@ from torch._C._functorch import (
|
|||
)
|
||||
from torch._functorch.utils import exposed_in
|
||||
|
||||
from torch._higher_order_ops.utils import autograd_not_implemented
|
||||
from torch._higher_order_ops.utils import (
|
||||
_has_potential_branch_input_alias,
|
||||
_has_potential_branch_input_mutation,
|
||||
_maybe_run_with_interpreter,
|
||||
_set_compilation_env,
|
||||
autograd_not_implemented,
|
||||
UnsupportedAliasMutationException,
|
||||
)
|
||||
|
||||
from torch._ops import HigherOrderOperator
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||
from torch.fx.experimental.proxy_tensor import (
|
||||
|
|
@ -26,27 +30,9 @@ from torch.fx.experimental.proxy_tensor import (
|
|||
track_tensor_tree,
|
||||
)
|
||||
from torch.fx.passes.shape_prop import _extract_tensor_metadata
|
||||
from torch.multiprocessing.reductions import StorageWeakRef
|
||||
from torch.utils._python_dispatch import _get_current_dispatch_mode
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _set_compilation_env():
|
||||
_old_is_tracing = torch.fx._symbolic_trace._is_fx_tracing_flag
|
||||
try:
|
||||
# We need to turn off the is_fx_tracing_flag. Remove this flag check from dyanmo
|
||||
# once we are confident fx tracing works with dynamo.
|
||||
torch.fx._symbolic_trace._is_fx_tracing_flag = False
|
||||
yield
|
||||
finally:
|
||||
torch.fx._symbolic_trace._is_fx_tracing_flag = _old_is_tracing
|
||||
|
||||
|
||||
@dataclass
|
||||
class UnsupportedAliasMutationException(RuntimeError):
|
||||
reason: str
|
||||
|
||||
|
||||
@exposed_in("torch")
|
||||
def cond(pred, true_fn, false_fn, operands):
|
||||
r"""
|
||||
|
|
@ -160,18 +146,6 @@ In order to do this, we need implementations for each of the dispatch keys.
|
|||
cond_op = HigherOrderOperator("cond")
|
||||
|
||||
|
||||
def _maybe_run_with_interpreter(fn):
|
||||
maybe_interpreted_fn = fn
|
||||
if isinstance(fn, torch.fx.GraphModule) and fx_traceback.has_preserved_node_meta():
|
||||
# Running graph with interpreter is needed for propagating the stack_trace
|
||||
def graph_with_interpreter(*args):
|
||||
with fx_traceback.preserve_node_meta():
|
||||
return torch.fx.Interpreter(fn).run(*args)
|
||||
|
||||
maybe_interpreted_fn = graph_with_interpreter
|
||||
return maybe_interpreted_fn
|
||||
|
||||
|
||||
def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
|
||||
assert isinstance(
|
||||
operands, (list, tuple)
|
||||
|
|
@ -304,90 +278,6 @@ def cond_fake_tensor_mode(mode, pred, true_fn, false_fn, operands):
|
|||
return true_outs
|
||||
|
||||
|
||||
def _has_potential_branch_input_mutation(branch, inputs):
|
||||
"""
|
||||
Dispatch-trace the branch with inputs and check if
|
||||
producing graph has mutable op on the input. This is
|
||||
bit restrictive as the branch must be traceable.
|
||||
"""
|
||||
try:
|
||||
gm = make_fx(branch)(*inputs)
|
||||
except UnsupportedAliasMutationException:
|
||||
# this can happen when nested cond_op is
|
||||
# functionalized
|
||||
return True
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def _detect_input_mutation(gm):
|
||||
input_nodes = set()
|
||||
for node in gm.graph.nodes:
|
||||
if node.op == "placeholder":
|
||||
input_nodes.add(node)
|
||||
if node.op == "call_function":
|
||||
target = node.target
|
||||
if (
|
||||
isinstance(target, torch._ops.OpOverload)
|
||||
and target._schema.is_mutable
|
||||
):
|
||||
for arg in node.args:
|
||||
if arg in input_nodes:
|
||||
return True
|
||||
|
||||
for _, module in gm.named_children():
|
||||
if isinstance(module, torch.fx.GraphModule):
|
||||
if _detect_input_mutation(module):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
return _detect_input_mutation(gm)
|
||||
|
||||
|
||||
def _has_potential_branch_input_alias(branch, inputs):
|
||||
"""
|
||||
Dispatch-trace the branch with inputs and check if
|
||||
producing graph has output aliasing the branch input. This is
|
||||
bit restrictive as the branch must be traceable.
|
||||
"""
|
||||
try:
|
||||
gm = make_fx(branch)(*inputs)
|
||||
|
||||
except UnsupportedAliasMutationException:
|
||||
# this can happen when nested cond_op is
|
||||
# functionalized
|
||||
return True
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def _detect_input_alias(gm):
|
||||
input_storages = set()
|
||||
for node in gm.graph.nodes:
|
||||
# We need to check existence of "val" because we reuse the logic here
|
||||
# for map operator, where num_mapped_args is a scalar
|
||||
# and doesn't have a "val" meta.
|
||||
if node.op == "placeholder" and "val" in node.meta:
|
||||
input_storages.add(StorageWeakRef(node.meta["val"]._typed_storage()))
|
||||
if node.op == "output":
|
||||
|
||||
def check_alias(out):
|
||||
if out is not None and "val" in out.meta:
|
||||
out_storage = StorageWeakRef(out.meta["val"]._typed_storage())
|
||||
return out_storage in input_storages
|
||||
return False
|
||||
|
||||
if any(pytree.tree_leaves(pytree.tree_map(check_alias, node.args))):
|
||||
return True
|
||||
|
||||
for _, module in gm.named_children():
|
||||
if isinstance(module, torch.fx.GraphModule) and _detect_input_alias(module):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
return _detect_input_alias(gm)
|
||||
|
||||
|
||||
@cond_op.py_functionalize_impl
|
||||
def cond_func(ctx, pred, true_fn, false_fn, inputs):
|
||||
unwrapped_inputs = ctx.unwrap_tensors(inputs)
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from torch._C import DispatchKey
|
|||
from torch._dispatch.python import suspend_functionalization
|
||||
from torch._functorch.aot_autograd import AOTConfig, create_joint, from_fun
|
||||
|
||||
from torch._higher_order_ops.cond import (
|
||||
from torch._higher_order_ops.utils import (
|
||||
_has_potential_branch_input_alias,
|
||||
_has_potential_branch_input_mutation,
|
||||
UnsupportedAliasMutationException,
|
||||
|
|
|
|||
|
|
@ -1,8 +1,18 @@
|
|||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable
|
||||
|
||||
import torch
|
||||
import torch.fx.traceback as fx_traceback
|
||||
import torch.utils._pytree as pytree
|
||||
from torch._ops import HigherOrderOperator
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch.multiprocessing.reductions import StorageWeakRef
|
||||
|
||||
|
||||
@dataclass
|
||||
class UnsupportedAliasMutationException(RuntimeError):
|
||||
reason: str
|
||||
|
||||
|
||||
def autograd_not_implemented_inner(
|
||||
|
|
@ -52,3 +62,111 @@ def autograd_not_implemented(op: HigherOrderOperator, deferred_error: bool) -> C
|
|||
return autograd_not_implemented_inner(op, deferred_error, *args, **kwargs)
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
def _maybe_run_with_interpreter(fn):
|
||||
maybe_interpreted_fn = fn
|
||||
if isinstance(fn, torch.fx.GraphModule) and fx_traceback.has_preserved_node_meta():
|
||||
# Running graph with interpreter is needed for propagating the stack_trace
|
||||
def graph_with_interpreter(*args):
|
||||
with fx_traceback.preserve_node_meta():
|
||||
return torch.fx.Interpreter(fn).run(*args)
|
||||
|
||||
maybe_interpreted_fn = graph_with_interpreter
|
||||
return maybe_interpreted_fn
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _set_compilation_env():
|
||||
_old_is_tracing = torch.fx._symbolic_trace._is_fx_tracing_flag
|
||||
try:
|
||||
# We need to turn off the is_fx_tracing_flag. Remove this flag check from dyanmo
|
||||
# once we are confident fx tracing works with dynamo.
|
||||
torch.fx._symbolic_trace._is_fx_tracing_flag = False
|
||||
yield
|
||||
finally:
|
||||
torch.fx._symbolic_trace._is_fx_tracing_flag = _old_is_tracing
|
||||
|
||||
|
||||
def _has_potential_branch_input_mutation(branch, inputs):
|
||||
"""
|
||||
Dispatch-trace the branch with inputs and check if
|
||||
producing graph has mutable op on the input. This is
|
||||
bit restrictive as the branch must be traceable.
|
||||
"""
|
||||
try:
|
||||
gm = make_fx(branch)(*inputs)
|
||||
except UnsupportedAliasMutationException:
|
||||
# this can happen when nested cond_op is
|
||||
# functionalized
|
||||
return True
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def _detect_input_mutation(gm):
|
||||
input_nodes = set()
|
||||
for node in gm.graph.nodes:
|
||||
if node.op == "placeholder":
|
||||
input_nodes.add(node)
|
||||
if node.op == "call_function":
|
||||
target = node.target
|
||||
if (
|
||||
isinstance(target, torch._ops.OpOverload)
|
||||
and target._schema.is_mutable
|
||||
):
|
||||
for arg in node.args:
|
||||
if arg in input_nodes:
|
||||
return True
|
||||
|
||||
for _, module in gm.named_children():
|
||||
if isinstance(module, torch.fx.GraphModule):
|
||||
if _detect_input_mutation(module):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
return _detect_input_mutation(gm)
|
||||
|
||||
|
||||
def _has_potential_branch_input_alias(branch, inputs):
|
||||
"""
|
||||
Dispatch-trace the branch with inputs and check if
|
||||
producing graph has output aliasing the branch input. This is
|
||||
bit restrictive as the branch must be traceable.
|
||||
"""
|
||||
try:
|
||||
gm = make_fx(branch)(*inputs)
|
||||
|
||||
except UnsupportedAliasMutationException:
|
||||
# this can happen when nested cond_op is
|
||||
# functionalized
|
||||
return True
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def _detect_input_alias(gm):
|
||||
input_storages = set()
|
||||
for node in gm.graph.nodes:
|
||||
# We need to check existence of "val" because we reuse the logic here
|
||||
# for map operator, where num_mapped_args is a scalar
|
||||
# and doesn't have a "val" meta.
|
||||
if node.op == "placeholder" and "val" in node.meta:
|
||||
input_storages.add(StorageWeakRef(node.meta["val"]._typed_storage()))
|
||||
if node.op == "output":
|
||||
|
||||
def check_alias(out):
|
||||
if out is not None and "val" in out.meta:
|
||||
out_storage = StorageWeakRef(out.meta["val"]._typed_storage())
|
||||
return out_storage in input_storages
|
||||
return False
|
||||
|
||||
if any(pytree.tree_leaves(pytree.tree_map(check_alias, node.args))):
|
||||
return True
|
||||
|
||||
for _, module in gm.named_children():
|
||||
if isinstance(module, torch.fx.GraphModule) and _detect_input_alias(module):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
return _detect_input_alias(gm)
|
||||
|
|
|
|||
Loading…
Reference in a new issue