From 113f0749f5b45e2f9e4a4c07b43ca413e873b3f3 Mon Sep 17 00:00:00 2001 From: ydwu4 Date: Fri, 5 Jan 2024 17:10:33 -0800 Subject: [PATCH] [HigherOrderOp] move some common utils in cond to utils.py (#116721) Pull Request resolved: https://github.com/pytorch/pytorch/pull/116721 Approved by: https://github.com/zou3519 --- torch/_higher_order_ops/cond.py | 128 +++---------------------------- torch/_higher_order_ops/map.py | 2 +- torch/_higher_order_ops/utils.py | 118 ++++++++++++++++++++++++++++ 3 files changed, 128 insertions(+), 120 deletions(-) diff --git a/torch/_higher_order_ops/cond.py b/torch/_higher_order_ops/cond.py index 4f292bf4c92..d5a62a6d90a 100644 --- a/torch/_higher_order_ops/cond.py +++ b/torch/_higher_order_ops/cond.py @@ -1,9 +1,5 @@ -from contextlib import contextmanager -from dataclasses import dataclass - import torch import torch._subclasses.functional_tensor -import torch.fx.traceback as fx_traceback import torch.utils._pytree as pytree @@ -16,7 +12,15 @@ from torch._C._functorch import ( ) from torch._functorch.utils import exposed_in -from torch._higher_order_ops.utils import autograd_not_implemented +from torch._higher_order_ops.utils import ( + _has_potential_branch_input_alias, + _has_potential_branch_input_mutation, + _maybe_run_with_interpreter, + _set_compilation_env, + autograd_not_implemented, + UnsupportedAliasMutationException, +) + from torch._ops import HigherOrderOperator from torch._subclasses.fake_tensor import FakeTensorMode from torch.fx.experimental.proxy_tensor import ( @@ -26,27 +30,9 @@ from torch.fx.experimental.proxy_tensor import ( track_tensor_tree, ) from torch.fx.passes.shape_prop import _extract_tensor_metadata -from torch.multiprocessing.reductions import StorageWeakRef from torch.utils._python_dispatch import _get_current_dispatch_mode -@contextmanager -def _set_compilation_env(): - _old_is_tracing = torch.fx._symbolic_trace._is_fx_tracing_flag - try: - # We need to turn off the is_fx_tracing_flag. Remove this flag check from dyanmo - # once we are confident fx tracing works with dynamo. - torch.fx._symbolic_trace._is_fx_tracing_flag = False - yield - finally: - torch.fx._symbolic_trace._is_fx_tracing_flag = _old_is_tracing - - -@dataclass -class UnsupportedAliasMutationException(RuntimeError): - reason: str - - @exposed_in("torch") def cond(pred, true_fn, false_fn, operands): r""" @@ -160,18 +146,6 @@ In order to do this, we need implementations for each of the dispatch keys. cond_op = HigherOrderOperator("cond") -def _maybe_run_with_interpreter(fn): - maybe_interpreted_fn = fn - if isinstance(fn, torch.fx.GraphModule) and fx_traceback.has_preserved_node_meta(): - # Running graph with interpreter is needed for propagating the stack_trace - def graph_with_interpreter(*args): - with fx_traceback.preserve_node_meta(): - return torch.fx.Interpreter(fn).run(*args) - - maybe_interpreted_fn = graph_with_interpreter - return maybe_interpreted_fn - - def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands): assert isinstance( operands, (list, tuple) @@ -304,90 +278,6 @@ def cond_fake_tensor_mode(mode, pred, true_fn, false_fn, operands): return true_outs -def _has_potential_branch_input_mutation(branch, inputs): - """ - Dispatch-trace the branch with inputs and check if - producing graph has mutable op on the input. This is - bit restrictive as the branch must be traceable. 
- """ - try: - gm = make_fx(branch)(*inputs) - except UnsupportedAliasMutationException: - # this can happen when nested cond_op is - # functionalized - return True - except Exception as e: - raise e - - def _detect_input_mutation(gm): - input_nodes = set() - for node in gm.graph.nodes: - if node.op == "placeholder": - input_nodes.add(node) - if node.op == "call_function": - target = node.target - if ( - isinstance(target, torch._ops.OpOverload) - and target._schema.is_mutable - ): - for arg in node.args: - if arg in input_nodes: - return True - - for _, module in gm.named_children(): - if isinstance(module, torch.fx.GraphModule): - if _detect_input_mutation(module): - return True - - return False - - return _detect_input_mutation(gm) - - -def _has_potential_branch_input_alias(branch, inputs): - """ - Dispatch-trace the branch with inputs and check if - producing graph has output aliasing the branch input. This is - bit restrictive as the branch must be traceable. - """ - try: - gm = make_fx(branch)(*inputs) - - except UnsupportedAliasMutationException: - # this can happen when nested cond_op is - # functionalized - return True - except Exception as e: - raise e - - def _detect_input_alias(gm): - input_storages = set() - for node in gm.graph.nodes: - # We need to check existence of "val" because we reuse the logic here - # for map operator, where num_mapped_args is a scalar - # and doesn't have a "val" meta. - if node.op == "placeholder" and "val" in node.meta: - input_storages.add(StorageWeakRef(node.meta["val"]._typed_storage())) - if node.op == "output": - - def check_alias(out): - if out is not None and "val" in out.meta: - out_storage = StorageWeakRef(out.meta["val"]._typed_storage()) - return out_storage in input_storages - return False - - if any(pytree.tree_leaves(pytree.tree_map(check_alias, node.args))): - return True - - for _, module in gm.named_children(): - if isinstance(module, torch.fx.GraphModule) and _detect_input_alias(module): - return True - - return False - - return _detect_input_alias(gm) - - @cond_op.py_functionalize_impl def cond_func(ctx, pred, true_fn, false_fn, inputs): unwrapped_inputs = ctx.unwrap_tensors(inputs) diff --git a/torch/_higher_order_ops/map.py b/torch/_higher_order_ops/map.py index 6ecbd60bb2e..0c93c4bf486 100644 --- a/torch/_higher_order_ops/map.py +++ b/torch/_higher_order_ops/map.py @@ -6,7 +6,7 @@ from torch._C import DispatchKey from torch._dispatch.python import suspend_functionalization from torch._functorch.aot_autograd import AOTConfig, create_joint, from_fun -from torch._higher_order_ops.cond import ( +from torch._higher_order_ops.utils import ( _has_potential_branch_input_alias, _has_potential_branch_input_mutation, UnsupportedAliasMutationException, diff --git a/torch/_higher_order_ops/utils.py b/torch/_higher_order_ops/utils.py index a26758b74cc..b100debc618 100644 --- a/torch/_higher_order_ops/utils.py +++ b/torch/_higher_order_ops/utils.py @@ -1,8 +1,18 @@ +from contextlib import contextmanager +from dataclasses import dataclass from typing import Any, Callable import torch +import torch.fx.traceback as fx_traceback import torch.utils._pytree as pytree from torch._ops import HigherOrderOperator +from torch.fx.experimental.proxy_tensor import make_fx +from torch.multiprocessing.reductions import StorageWeakRef + + +@dataclass +class UnsupportedAliasMutationException(RuntimeError): + reason: str def autograd_not_implemented_inner( @@ -52,3 +62,111 @@ def autograd_not_implemented(op: HigherOrderOperator, deferred_error: bool) -> C 
         return autograd_not_implemented_inner(op, deferred_error, *args, **kwargs)
 
     return inner
+
+
+def _maybe_run_with_interpreter(fn):
+    maybe_interpreted_fn = fn
+    if isinstance(fn, torch.fx.GraphModule) and fx_traceback.has_preserved_node_meta():
+        # Running graph with interpreter is needed for propagating the stack_trace
+        def graph_with_interpreter(*args):
+            with fx_traceback.preserve_node_meta():
+                return torch.fx.Interpreter(fn).run(*args)
+
+        maybe_interpreted_fn = graph_with_interpreter
+    return maybe_interpreted_fn
+
+
+@contextmanager
+def _set_compilation_env():
+    _old_is_tracing = torch.fx._symbolic_trace._is_fx_tracing_flag
+    try:
+        # We need to turn off the is_fx_tracing_flag. Remove this flag check from dynamo
+        # once we are confident fx tracing works with dynamo.
+        torch.fx._symbolic_trace._is_fx_tracing_flag = False
+        yield
+    finally:
+        torch.fx._symbolic_trace._is_fx_tracing_flag = _old_is_tracing
+
+
+def _has_potential_branch_input_mutation(branch, inputs):
+    """
+    Dispatch-trace the branch with inputs and check whether the
+    produced graph has a mutable op on the input. This is a bit
+    restrictive, as the branch must be traceable.
+    """
+    try:
+        gm = make_fx(branch)(*inputs)
+    except UnsupportedAliasMutationException:
+        # this can happen when nested cond_op is
+        # functionalized
+        return True
+    except Exception as e:
+        raise e
+
+    def _detect_input_mutation(gm):
+        input_nodes = set()
+        for node in gm.graph.nodes:
+            if node.op == "placeholder":
+                input_nodes.add(node)
+            if node.op == "call_function":
+                target = node.target
+                if (
+                    isinstance(target, torch._ops.OpOverload)
+                    and target._schema.is_mutable
+                ):
+                    for arg in node.args:
+                        if arg in input_nodes:
+                            return True
+
+        for _, module in gm.named_children():
+            if isinstance(module, torch.fx.GraphModule):
+                if _detect_input_mutation(module):
+                    return True
+
+        return False
+
+    return _detect_input_mutation(gm)
+
+
+def _has_potential_branch_input_alias(branch, inputs):
+    """
+    Dispatch-trace the branch with inputs and check whether the
+    produced graph has an output aliasing the branch input. This is
+    a bit restrictive, as the branch must be traceable.
+    """
+    try:
+        gm = make_fx(branch)(*inputs)
+
+    except UnsupportedAliasMutationException:
+        # this can happen when nested cond_op is
+        # functionalized
+        return True
+    except Exception as e:
+        raise e
+
+    def _detect_input_alias(gm):
+        input_storages = set()
+        for node in gm.graph.nodes:
+            # We need to check existence of "val" because we reuse the logic here
+            # for map operator, where num_mapped_args is a scalar
+            # and doesn't have a "val" meta.
+            if node.op == "placeholder" and "val" in node.meta:
+                input_storages.add(StorageWeakRef(node.meta["val"]._typed_storage()))
+            if node.op == "output":
+
+                def check_alias(out):
+                    if out is not None and "val" in out.meta:
+                        out_storage = StorageWeakRef(out.meta["val"]._typed_storage())
+                        return out_storage in input_storages
+                    return False
+
+                if any(pytree.tree_leaves(pytree.tree_map(check_alias, node.args))):
+                    return True
+
+        for _, module in gm.named_children():
+            if isinstance(module, torch.fx.GraphModule) and _detect_input_alias(module):
+                return True
+
+        return False
+
+    return _detect_input_alias(gm)
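
Note on the moved helpers: after this patch, _set_compilation_env, _maybe_run_with_interpreter, _has_potential_branch_input_mutation, _has_potential_branch_input_alias, and UnsupportedAliasMutationException are imported from torch._higher_order_ops.utils (map.py already switches its import above). The sketch below is a hypothetical sanity check, not part of the patch: the branch functions and sample inputs are made up, and because these helpers are private, the observed behavior may vary between PyTorch versions.

import torch
from torch._higher_order_ops.utils import (
    _has_potential_branch_input_alias,
    _has_potential_branch_input_mutation,
)
from torch._subclasses.fake_tensor import FakeTensorMode


def clean_branch(x):
    # Pure function of the input: no in-place ops, no views of x escape.
    return x.sin() + 1


def mutating_branch(x):
    # Mutates its input in place, so the mutation check should flag it.
    x.add_(1)
    return x.clone()


def aliasing_branch(x):
    # Returns a view of the input, so the alias check should flag it.
    return x.view(-1)


# Fake inputs mirror how these checks are used during functionalization,
# where branches are traced with fake tensors rather than real data.
with FakeTensorMode():
    inputs = (torch.randn(2, 3),)

print(_has_potential_branch_input_mutation(clean_branch, inputs))     # expected False
print(_has_potential_branch_input_mutation(mutating_branch, inputs))  # expected True
print(_has_potential_branch_input_alias(aliasing_branch, inputs))     # expected True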
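
For context, cond itself is unchanged by this refactor apart from its imports. A minimal usage sketch follows, assuming the signature cond(pred, true_fn, false_fn, operands) shown in cond.py; true_fn, false_fn, and x are illustrative names, and the same function is exposed publicly as torch.cond via exposed_in("torch"). Both branches must accept the same operands, return tensors of matching shape and dtype, and must neither mutate nor alias their inputs, which is exactly what the two helpers above guard against. Outside of torch.compile, cond may route through dynamo internally, so this sketch assumes a dynamo-capable environment.

import torch
from torch._higher_order_ops.cond import cond


def true_fn(x):
    # Taken when pred is True; output must match false_fn's shape and dtype.
    return x.sin()


def false_fn(x):
    # Taken when pred is False.
    return x.cos()


x = torch.randn(4)
pred = x.sum() > 0  # a boolean scalar tensor
out = cond(pred, true_fn, false_fn, (x,))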