diff --git a/functorch/experimental/_cond.py b/functorch/experimental/_cond.py
index dec5341dacc..95625866bfb 100644
--- a/functorch/experimental/_cond.py
+++ b/functorch/experimental/_cond.py
@@ -36,7 +36,6 @@ In order to do this, we need implementations for each of the dispatch keys.
 """
 cond = HigherOrderOperator("cond")
 
-
 def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
     assert isinstance(operands, (list, tuple)), "Cond operands must be a list or tuple of tensors"
     assert all(isinstance(o, torch.Tensor) for o in operands), "Cond operands must be a list of tensors"
@@ -159,7 +158,6 @@ def cond_fake_tensor_mode(pred, true_fn, false_fn, operands):
     return true_outs
 
 
-
 def _has_potential_branch_input_mutation(branch, inputs):
     """
     Dispatch-trace the branch with inputs and check if
@@ -175,18 +173,27 @@ def _has_potential_branch_input_mutation(branch, inputs):
     except Exception as e:
         raise e
 
-    input_nodes = set()
-    for node in gm.graph.nodes:
-        if node.op == "placeholder":
-            input_nodes.add(node)
-        if node.op == "call_function":
-            target = node.target
-            if isinstance(target, torch._ops.OpOverload) and target._schema.is_mutable:
-                for arg in node.args:
-                    if arg in input_nodes:
-                        return True
+    def _detect_input_mutation(gm):
+        input_nodes = set()
+        for node in gm.graph.nodes:
+            if node.op == "placeholder":
+                input_nodes.add(node)
+            if node.op == "call_function":
+                target = node.target
+                if isinstance(target, torch._ops.OpOverload) and target._schema.is_mutable:
+                    for arg in node.args:
+                        if arg in input_nodes:
+                            return True
+
+        for _, module in gm.named_children():
+            if isinstance(module, torch.fx.GraphModule):
+                if _detect_input_mutation(module):
+                    return True
+
+        return False
+
+    return _detect_input_mutation(gm)
 
-    return False
 
 def _has_potential_branch_input_alias(branch, inputs):
     """
@@ -204,18 +211,50 @@ def _has_potential_branch_input_alias(branch, inputs):
     except Exception as e:
         raise e
 
-    input_storages = set()
-    for node in gm.graph.nodes:
-        if node.op == "placeholder":
-            input_storages.add(StorageWeakRef(node.meta['val']._typed_storage()))
-        if node.op == "output":
-            for out in node.args:
-                out_storage = StorageWeakRef(out.meta["val"]._typed_storage())
-                if out_storage in input_storages:
-                    return True
+    def _detect_input_alias(gm):
+        input_storages = set()
+        for node in gm.graph.nodes:
+            if node.op == "placeholder":
+                input_storages.add(StorageWeakRef(node.meta['val']._typed_storage()))
+            if node.op == "output":
+                for out in node.args:
+                    out_storage = StorageWeakRef(out.meta["val"]._typed_storage())
+                    if out_storage in input_storages:
+                        return True
 
-    return False
+        for _, module in gm.named_children():
+            if isinstance(module, torch.fx.GraphModule) and _detect_input_alias(module):
+                return True
+
+        return False
+
+    return _detect_input_alias(gm)
+
+
+@cond.py_impl(DispatchKey.Functionalize)
+def cond_func(pred, true_fn, false_fn, inputs):
+    reapply_views = torch._C._functionalization_reapply_views_tls()
+    unwrapped_inputs = _unwrap_all_tensors_from_functional(inputs, reapply_views=reapply_views)
+    unwrapped_pred = _unwrap_all_tensors_from_functional(pred, reapply_views=reapply_views)
+    mode = 'mutations_and_views' if reapply_views else 'mutations'
+    guard = ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize))
+    try:
+        functional_true = functionalize(true_fn, remove=mode)
+        functional_false = functionalize(false_fn, remove=mode)
+        for branch in [true_fn, false_fn]:
+            if _has_potential_branch_input_mutation(branch, unwrapped_inputs):
+                raise UnsupportedAliasMutationException("One of torch.cond branch "
+                                                        "might be modifying the input!")
+
+            if _has_potential_branch_input_alias(branch, unwrapped_inputs):
+                raise UnsupportedAliasMutationException("One of torch.cond branch "
+                                                        "might be aliasing the input!")
+
+        cond_return = cond(unwrapped_pred, functional_true, functional_false, unwrapped_inputs)
+        return _wrap_all_tensors_to_functional(cond_return, level=0)
+
+    finally:
+        del guard
 
 
 @cond.py_impl(torch._C._functorch.TransformType.Functionalize)
diff --git a/functorch/experimental/_map.py b/functorch/experimental/_map.py
index 08fa82d73d3..7f75cd108b9 100644
--- a/functorch/experimental/_map.py
+++ b/functorch/experimental/_map.py
@@ -91,6 +91,34 @@ def map_fake_tensor_mode(f, xs, *args):
     outs = [f(x, *args) for x in xs]
     return outs[0].new_empty([xs.shape[0], *outs[0].shape])
 
+
+@map.py_impl(DispatchKey.Functionalize)
+def map_func(f, xs, *args):
+    reapply_views = torch._C._functionalization_reapply_views_tls()
+    unwrapped_xs = _unwrap_all_tensors_from_functional(xs, reapply_views=reapply_views)
+    unwrapped_args = _unwrap_all_tensors_from_functional(args, reapply_views=reapply_views)
+    mode = 'mutations_and_views' if reapply_views else 'mutations'
+
+    guard = ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize))
+    try:
+        functional_map_fn = functionalize(f, remove=mode)
+        inputs = (unwrapped_xs,) + unwrapped_args
+
+        if _has_potential_branch_input_mutation(f, inputs):
+            raise UnsupportedAliasMutationException(
+                "torch.map is mutating the input!"
+            )
+
+        if _has_potential_branch_input_alias(f, inputs):
+            raise UnsupportedAliasMutationException(
+                "torch.map is aliasing the input!"
+            )
+
+        map_return = map(functional_map_fn, unwrapped_xs, *unwrapped_args)
+        return _wrap_all_tensors_to_functional(map_return, level=0)
+    finally:
+        del guard
+
+
 @map.py_impl(torch._C._functorch.TransformType.Functionalize)
 def map_functionalize(interpreter, f, xs, *args):
     """
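The two DispatchKey.Functionalize implementations above share one pattern: unwrap the functional tensors, reject branches that mutate or alias their inputs, re-dispatch to the underlying higher-order op, and re-wrap the result. As a rough illustration of what the mutation check rejects from the user side, modelled on the existing tests in test/functorch/test_control_flow.py shown further down (the branch functions and tensor shapes here are illustrative only):

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from functorch.experimental.control_flow import cond, UnsupportedAliasMutationException

def true_fn(x):
    x.add_(4)  # in-place mutation of the branch input
    return x.sin().max()

def false_fn(x):
    return x.sum()

def f(x):
    return cond(x.shape[0] == 1, true_fn, false_fn, [x])

try:
    # Tracing under functionalization routes cond through the Functionalize
    # implementation, which checks both branches regardless of the predicate.
    make_fx(torch.func.functionalize(f))(torch.ones(4, 5))
except UnsupportedAliasMutationException as e:
    print("rejected:", e)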
diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py
index 2e64f3966c0..818b1bf5c9f 100644
--- a/test/dynamo/test_export.py
+++ b/test/dynamo/test_export.py
@@ -12,8 +12,6 @@ from enum import Enum
 from typing import Dict, List, Sequence
 from unittest.mock import patch
 
-import pytest
-
 import torch
 import torch._dynamo
 import torch._dynamo.test_case
@@ -2335,7 +2333,6 @@ def forward(self, x):
         preserved = True
         self.assertTrue(preserved)
 
-    @pytest.mark.xfail(reason="Saving example_fake_inputs breaks the serialization")
     @config.patch(
         dynamic_shapes=True,
         capture_dynamic_output_shape_ops=True,
@@ -2655,6 +2652,113 @@ def forward(self, x):
         self.assertEqual(f_correct(torch.ones(6, 4)), gm(torch.ones(6, 4)))
 
+    @config.patch(dynamic_shapes=True)
+    def test_functionalize(self):
+        class Foo(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.register_buffer("buffer1", torch.ones(6, 2))
+
+            def forward(self, x):
+                x.add_(2)
+                return x.sum() + self.buffer1.sum()
+
+        example_inputs = (torch.ones(1, 2, 3),)
+        gm, _ = torch._dynamo.export(
+            Foo(),
+            *example_inputs,
+            aten_graph=True,
+            tracing_mode="symbolic",
+            functionalize=True,
+        )
+
+        count = 0
+        for node in gm.graph.nodes:
+            if node.target == torch.ops.aten.add_.Tensor:
+                count += 1
+        self.assertEqual(count, 0)
+        test_inp = (torch.ones(1, 2, 3),)
+        test_inp_v2 = (torch.ones(1, 2, 3),)
+        self.assertEqual(gm(*test_inp), Foo()(*test_inp_v2))
+
+    @config.patch(dynamic_shapes=True)
+    def test_not_functionalize(self):
+        class Foo(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.register_buffer("buffer1", torch.ones(6, 2))
+
+            def forward(self, x):
+                x.add_(2)
+                return x.sum() + self.buffer1.sum()
+
+        example_inputs = (torch.ones(1, 2, 3),)
+        gm, _ = torch._dynamo.export(
+            Foo(),
+            *example_inputs,
+            aten_graph=True,
+            tracing_mode="symbolic",
+            functionalize=False,
+        )
+        count = 0
+        for node in gm.graph.nodes:
+            if node.target == torch.ops.aten.add_.Tensor:
+                count += 1
+        self.assertEqual(count, 1)
+        test_inp = (torch.ones(1, 2, 3),)
+        test_inp_v2 = (torch.ones(1, 2, 3),)
+        self.assertEqual(gm(*test_inp), Foo()(*test_inp_v2))
+
+    @config.patch(dynamic_shapes=True, assume_static_by_default=False)
+    def test_functionalize_cond(self):
+        def foo(x):
+            def true_true_fn(x):
+                return x.sum() + 6
+
+            def true_false_fn(x):
+                return x.sum() + 9
+
+            def true_fn(x):
+                return cond(x.shape[0] > 6, true_true_fn, true_false_fn, [x])
+
+            def false_fn(x):
+                return x.sum() - 1
+
+            return cond(x.shape[0] > 5, true_fn, false_fn, [x])
+
+        example_inputs = (torch.ones(5, 2, 3),)
+        gm, _ = torch._dynamo.export(
+            foo,
+            *example_inputs,
+            aten_graph=True,
+            tracing_mode="symbolic",
+            functionalize=True,
+        )
+        self.assertEqual(gm(torch.ones(7, 2, 3)), foo(torch.ones(7, 2, 3)))
+
+    @config.patch(dynamic_shapes=True)
+    def test_functionalize_simple(self):
+        def foo(x):
+            def true_fn(x):
+                return x.sum() + 1
+
+            def false_fn(x):
+                return x.sum() - 1
+
+            return cond(x.shape[0] > 5, true_fn, false_fn, [x])
+
+        example_inputs = (torch.ones(5, 2, 3),)
+        gm, _ = torch._dynamo.export(
+            foo,
+            *example_inputs,
+            aten_graph=True,
+            tracing_mode="symbolic",
+            functionalize=True,
+        )
+        self.assertEqual(gm.true_graph_0(torch.ones(6, 4)), torch.ones(6, 4).sum() + 1)
+        self.assertEqual(gm.false_graph_0(torch.ones(6, 4)), torch.ones(6, 4).sum() - 1)
+
     @config.patch(dynamic_shapes=True)
     def test_round_dynamic_shapes(self):
         def f(x):
             return x[: round(x.shape[0] / 2)]
diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py
index 51dfc722c33..c01ab9cccb0 100644
--- a/test/functorch/test_control_flow.py
+++ b/test/functorch/test_control_flow.py
@@ -1,7 +1,10 @@
 # Owner(s): ["module: functorch"]
+import functools
 import unittest
 
 import torch
+import torch.utils._pytree as pytree
+from torch._functorch.aot_autograd import from_fun, to_fun
 from functorch.experimental import control_flow
 from functorch.experimental.control_flow import cond
 from functorch.experimental.control_flow import UnsupportedAliasMutationException
@@ -298,6 +301,125 @@ class TestControlFlowTraced(TestCase):
         with self.assertRaisesRegex(UnsupportedAliasMutationException, "One of torch.cond branch"):
             make_fx(torch.func.functionalize(f))(*example_inputs)
 
+    def test_cond_functionalized_nested_input_mutation_with_aot_func(self):
+        def true_true_fn(x):
+            x.add_(4)
+            return x.sin().max()
+
+        def true_false_fn(x):
+            return x.cos().min()
+
+        def true_fn(x):
+            pred = x.shape[0] == 1
+            return cond(pred, true_true_fn, true_false_fn, [x])
+
+        def false_fn(x):
+            return x.sum()
+
+        def f(x):
+            pred = x.shape[0] == 1
+            return cond(pred, true_fn, false_fn, [x])
+
+        example_input = torch.ones(4, 5)
+        example_input_func = to_fun(example_input)
+        torch._enable_functionalization(reapply_views=False)
+        try:
+            with self.assertRaisesRegex(UnsupportedAliasMutationException, "One of torch.cond branch"):
+                f(example_input_func)
+
+            with self.assertRaisesRegex(UnsupportedAliasMutationException, "One of torch.cond branch"):
+                make_fx(f)(example_input_func)
+        finally:
+            torch._disable_functionalization()
+
+        def f_wrapper(func):
+            @functools.wraps(func)
+            def wrapper(*args, **kwargs):
+                torch._enable_functionalization(reapply_views=False)
+                try:
+                    return func(*args, **kwargs)
+                finally:
+                    torch._disable_functionalization()
+            return wrapper
+
+        with self.assertRaisesRegex(UnsupportedAliasMutationException, "One of torch.cond branch"):
+            make_fx(f_wrapper(f))(example_input_func)
+
+    def test_cond_functionalized_input_aliasing_with_aot_func(self):
+        def true_fn(x):
+            return x
+
+        def false_fn(x):
+            view_x = x.view(x.shape)
+            return view_x
+
+        def f(x):
+            pred = x.shape[0] == 4
+            return cond(pred, true_fn, false_fn, [x])
+
+        example_input = torch.ones(5, 5)
+        example_input_func = to_fun(example_input)
+        torch._enable_functionalization(reapply_views=False)
+        try:
+            with self.assertRaisesRegex(UnsupportedAliasMutationException, "One of torch.cond branch might be aliasing"):
+                f(example_input_func)
+        finally:
+            torch._disable_functionalization()
+
+        def f_wrapper(func):
+            @functools.wraps(func)
+            def wrapper(*args, **kwargs):
+                torch._enable_functionalization(reapply_views=False)
+                try:
+                    func_args = pytree.tree_map(to_fun, args)
+                    func_kwargs = pytree.tree_map(to_fun, kwargs)
+                    return func(*func_args, **func_kwargs)
+                finally:
+                    torch._disable_functionalization()
+            return wrapper
+
+        with self.assertRaisesRegex(UnsupportedAliasMutationException, "One of torch.cond branch might be aliasing"):
+            make_fx(f_wrapper(f))(example_input)
+
+    def test_cond_functionalized_aot_func_check_functional(self):
+        def true_fn(x):
+            return x.cos()
+
+        def false_fn(x):
+            y = x.sin()
+            y.add_(5)
+            return y
+
+        def f(x):
+            pred = x.shape[0] == 4
+            return cond(pred, true_fn, false_fn, [x])
+
+        example_input = torch.ones(5, 5)
+        example_input_func = to_fun(example_input)
+
+        def f_wrapper(func):
+            @functools.wraps(func)
+            def wrapper(*args, **kwargs):
+                torch._enable_functionalization(reapply_views=False)
+                try:
+                    func_args = pytree.tree_map(to_fun, args)
+                    func_kwargs = pytree.tree_map(to_fun, kwargs)
+                    return pytree.tree_map(from_fun, func(*func_args, **func_kwargs))
+                finally:
+                    torch._disable_functionalization()
+            return wrapper
+
+        result_gm = make_fx(f_wrapper(f))(example_input)
+        for node in result_gm.true_graph_0.graph.nodes:
+            if node.op == "call_function":
+                self.assertTrue(not node.target._schema.is_mutable)
+
+        for node in result_gm.false_graph_0.graph.nodes:
+            if node.op == "call_function":
+                self.assertTrue(not node.target._schema.is_mutable)
+
+        self.assertEqual(result_gm(torch.ones(5, 5)), f(torch.ones(5, 5)))
+
     def test_cond_nested_traced_other_inputs(self):
         def true_nested(y):
             return y * y
@@ -641,6 +763,35 @@ class TestControlFlowTraced(TestCase):
             if node.op == "call_function":
                 self.assertTrue(not node.target._schema.is_mutable)
 
+    def test_map_functionalized_aot_func(self):
+        def map_fn(x, y):
+            z = x + y
+            z.add_(4)
+            return z
+
+        def f(xs, y):
+            return control_flow.map(map_fn, xs, y)
+
+        def f_wrapper(func):
+            @functools.wraps(func)
+            def wrapper(*args, **kwargs):
+                torch._enable_functionalization(reapply_views=False)
+                try:
+                    return pytree.tree_map(from_fun, func(*args, **kwargs))
+                finally:
+                    torch._disable_functionalization()
+            return wrapper
+
+        example_inputs = (torch.ones(3, 2, 4), torch.ones(4))
+
+        gm = make_fx(f_wrapper(f))(*example_inputs)
+
+        for node in gm.body_graph_0.graph.nodes:
+            if node.op == "call_function":
+                self.assertTrue(not node.target._schema.is_mutable)
+
+        self.assertEqual(gm(*example_inputs), f(*example_inputs))
+
     def test_map_functionalized_arg_mutation(self):
         def map_fn(x, y):
             y.add_(4)
diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py
index 004472ca2a1..328662d011f 100644
--- a/torch/_dynamo/eval_frame.py
+++ b/torch/_dynamo/eval_frame.py
@@ -703,6 +703,7 @@ def export(
     tracing_mode: str = "symbolic",
     constraints: List[Constraint] = None,
     assume_static_by_default: bool = False,
+    functionalize: bool = False,
     **kwargs,
 ) -> Tuple[torch.fx.GraphModule, Set[_guards.Guard]]:
     """
@@ -727,6 +728,9 @@ def export(
         tracing_mode (str): If "symbolic", turn on dynamic shapes support. Default is "symbolic".
 
+        functionalize (bool): If True, the resulting aten graph module will be functional. You will need to
+            set aten_graph=True to see the effect. By default, this flag is False.
+
         **kwargs: Arbitrary keyword arguments to be passed to the function f.
 
     Returns:
@@ -752,6 +756,12 @@ def export(
     assert aten_graph, "pre_autograd=True can only be used when aten_graph=True"
     f = innermost_fn(f)
 
+    if functionalize and not aten_graph:
+        raise UserError(
+            UserErrorType.ANTI_PATTERN,
+            "TorchDynamo won't functionalize non-aten graphs. Please set `aten_graph=True` when using `functionalize=True`",
+        )
+
     graph = None
     out_guards = None
     graph_captured_input = None
@@ -906,10 +916,40 @@ def export(
     example_fake_inputs = [fake_mode.from_tensor(t) for t in example_inputs]
 
     if aten_graph:
+        memo: Dict[torch.Tensor, torch.Tensor] = {}
+
+        def to_fun(t):
+            if isinstance(t, torch.Tensor):
+                if t in memo:
+                    return memo[t]
+                r = torch._to_functional_tensor(t, mirror_autograd_meta=True)
+                memo[t] = r
+                return r
+            else:
+                return t
+
+        def from_fun(t):
+            if not isinstance(t, torch.Tensor) or not torch._is_functional_tensor(t):
+                return t
+            torch._sync(t)
+            return torch._from_functional_tensor(t)
+
         # Running graph with interpreter is needed for propagating the stack_trace
         def graph_with_interpreter(*args):
             with torch.fx.traceback.preserve_node_meta():
-                return torch.fx.Interpreter(graph).run(*args)
+                if functionalize:
+                    torch._enable_functionalization(reapply_views=True)
+                    try:
+                        return pytree.tree_map(
+                            from_fun,
+                            torch.fx.Interpreter(graph).run(
+                                *pytree.tree_map(to_fun, args)
+                            ),
+                        )
+                    finally:
+                        torch._disable_functionalization()
+                else:
+                    return torch.fx.Interpreter(graph).run(*args)
 
         with enable_python_dispatcher(), fake_mode:
             try:
@@ -1033,9 +1073,7 @@ def export(
     )
     new_graph.recompile()
 
-    # TODO remove this once Executorch uses proper functionalization
-    new_graph._example_fake_inputs = example_fake_inputs
     new_graph._matched_input_elements_positions = matched_input_elements_positions
 
     return (new_graph, out_guards)
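A usage sketch for the new export flag, not part of the patch; it follows test_functionalize above, so the Foo module, input shapes, and config values simply mirror that test:

import torch
import torch._dynamo

torch._dynamo.config.dynamic_shapes = True  # the test runs under config.patch(dynamic_shapes=True)

class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("buffer1", torch.ones(6, 2))

    def forward(self, x):
        x.add_(2)  # in-place op that functionalization should eliminate
        return x.sum() + self.buffer1.sum()

# functionalize=True is only honored together with aten_graph=True
# (the UserError added in export() above enforces this combination).
gm, guards = torch._dynamo.export(
    Foo(),
    torch.ones(1, 2, 3),
    aten_graph=True,
    tracing_mode="symbolic",
    functionalize=True,
)

# The exported aten graph should contain no mutable ops such as aten.add_.Tensor.
assert not any(node.target == torch.ops.aten.add_.Tensor for node in gm.graph.nodes)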