Tugsbayasgalan Manlaibaatar 2023-05-04 21:06:41 -07:00 committed by PyTorch MergeBot
parent 31fdd19b5b
commit bf08b072a7
5 changed files with 389 additions and 29 deletions

View file

@ -36,7 +36,6 @@ In order to do this, we need implementations for each of the dispatch keys.
"""
cond = HigherOrderOperator("cond")
def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
assert isinstance(operands, (list, tuple)), "Cond operands must be a list or tuple of tensors"
assert all(isinstance(o, torch.Tensor) for o in operands), "Cond operands must be a list of tensors"
@ -159,7 +158,6 @@ def cond_fake_tensor_mode(pred, true_fn, false_fn, operands):
return true_outs
def _has_potential_branch_input_mutation(branch, inputs):
"""
Dispatch-trace the branch with inputs and check if
@ -175,18 +173,27 @@ def _has_potential_branch_input_mutation(branch, inputs):
except Exception as e:
raise e
input_nodes = set()
for node in gm.graph.nodes:
if node.op == "placeholder":
input_nodes.add(node)
if node.op == "call_function":
target = node.target
if isinstance(target, torch._ops.OpOverload) and target._schema.is_mutable:
for arg in node.args:
if arg in input_nodes:
return True
def _detect_input_mutation(gm):
input_nodes = set()
for node in gm.graph.nodes:
if node.op == "placeholder":
input_nodes.add(node)
if node.op == "call_function":
target = node.target
if isinstance(target, torch._ops.OpOverload) and target._schema.is_mutable:
for arg in node.args:
if arg in input_nodes:
return True
for _, module in gm.named_children():
if isinstance(module, torch.fx.GraphModule):
if _detect_input_mutation(module):
return True
return False
return _detect_input_mutation(gm)
return False
def _has_potential_branch_input_alias(branch, inputs):
"""
@ -204,18 +211,50 @@ def _has_potential_branch_input_alias(branch, inputs):
except Exception as e:
raise e
input_storages = set()
for node in gm.graph.nodes:
if node.op == "placeholder":
input_storages.add(StorageWeakRef(node.meta['val']._typed_storage()))
if node.op == "output":
for out in node.args:
out_storage = StorageWeakRef(out.meta["val"]._typed_storage())
if out_storage in input_storages:
return True
def _detect_input_alias(gm):
input_storages = set()
for node in gm.graph.nodes:
if node.op == "placeholder":
input_storages.add(StorageWeakRef(node.meta['val']._typed_storage()))
if node.op == "output":
for out in node.args:
out_storage = StorageWeakRef(out.meta["val"]._typed_storage())
if out_storage in input_storages:
return True
return False
for _, module in gm.named_children():
if isinstance(module, torch.fx.GraphModule) and _detect_input_alias(module):
return True
return False
return _detect_input_alias(gm)
@cond.py_impl(DispatchKey.Functionalize)
def cond_func(pred, true_fn, false_fn, inputs):
reapply_views = torch._C._functionalization_reapply_views_tls()
unwrapped_inputs = _unwrap_all_tensors_from_functional(inputs, reapply_views=reapply_views)
unwrapped_pred = _unwrap_all_tensors_from_functional(pred, reapply_views=reapply_views)
mode = 'mutations_and_views' if reapply_views else 'mutations'
guard = ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize))
try:
functional_true = functionalize(true_fn, remove=mode)
functional_false = functionalize(false_fn, remove=mode)
for branch in [true_fn, false_fn]:
if _has_potential_branch_input_mutation(branch, unwrapped_inputs):
raise UnsupportedAliasMutationException("One of torch.cond branch "
"might be modifying the input!")
if _has_potential_branch_input_alias(branch, unwrapped_inputs):
raise UnsupportedAliasMutationException("One of torch.cond branch "
"might be aliasing the input!")
cond_return = cond(unwrapped_pred, functional_true, functional_false, unwrapped_inputs)
return _wrap_all_tensors_to_functional(cond_return, level=0)
finally:
del guard
@cond.py_impl(torch._C._functorch.TransformType.Functionalize)

View file

@ -91,6 +91,34 @@ def map_fake_tensor_mode(f, xs, *args):
outs = [f(x, *args) for x in xs]
return outs[0].new_empty([xs.shape[0], *outs[0].shape])
@map.py_impl(DispatchKey.Functionalize)
def map_func(f, xs, *args):
reapply_views = torch._C._functionalization_reapply_views_tls()
unwrapped_xs = _unwrap_all_tensors_from_functional(xs, reapply_views=reapply_views)
unwrapped_args = _unwrap_all_tensors_from_functional(args, reapply_views=reapply_views)
mode = 'mutations_and_views' if reapply_views else 'mutations'
guard = ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize))
try:
functional_map_fn = functionalize(f, remove=mode)
inputs = (unwrapped_xs,) + unwrapped_args
if _has_potential_branch_input_mutation(f, inputs):
raise UnsupportedAliasMutationException(
"torch.map is mutating the input!"
)
if _has_potential_branch_input_alias(f, inputs):
raise UnsupportedAliasMutationException(
"torch.map is aliasing the input!"
)
map_return = map(functional_map_fn, unwrapped_xs, *unwrapped_args)
return _wrap_all_tensors_to_functional(map_return, level=0)
finally:
del guard
@map.py_impl(torch._C._functorch.TransformType.Functionalize)
def map_functionalize(interpreter, f, xs, *args):
"""

View file

@ -12,8 +12,6 @@ from enum import Enum
from typing import Dict, List, Sequence
from unittest.mock import patch
import pytest
import torch
import torch._dynamo
import torch._dynamo.test_case
@ -2335,7 +2333,6 @@ def forward(self, x):
preserved = True
self.assertTrue(preserved)
@pytest.mark.xfail(reason="Saving example_fake_inputs breaks the serialization")
@config.patch(
dynamic_shapes=True,
capture_dynamic_output_shape_ops=True,
@ -2655,6 +2652,113 @@ def forward(self, x):
self.assertEqual(f_correct(torch.ones(6, 4)), gm(torch.ones(6, 4)))
@config.patch(dynamic_shapes=True)
def test_functionalize(self):
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("buffer1", torch.ones(6, 2))
def forward(self, x):
x.add_(2)
return x.sum() + self.buffer1.sum()
example_inputs = (torch.ones(1, 2, 3),)
gm, _ = torch._dynamo.export(
Foo(),
*example_inputs,
aten_graph=True,
tracing_mode="symbolic",
functionalize=True,
)
count = 0
for node in gm.graph.nodes:
if node.target == torch.ops.aten.add_.Tensor:
count += 1
self.assertEqual(count, 0)
test_inp = (torch.ones(1, 2, 3),)
test_inp_v2 = (torch.ones(1, 2, 3),)
self.assertEqual(gm(*test_inp), Foo()(*test_inp_v2))
@config.patch(dynamic_shapes=True)
def test_not_functionalize(self):
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("buffer1", torch.ones(6, 2))
def forward(self, x):
x.add_(2)
return x.sum() + self.buffer1.sum()
example_inputs = (torch.ones(1, 2, 3),)
gm, _ = torch._dynamo.export(
Foo(),
*example_inputs,
aten_graph=True,
tracing_mode="symbolic",
functionalize=False,
)
count = 0
for node in gm.graph.nodes:
if node.target == torch.ops.aten.add_.Tensor:
count += 1
self.assertEqual(count, 1)
test_inp = (torch.ones(1, 2, 3),)
test_inp_v2 = (torch.ones(1, 2, 3),)
self.assertEqual(gm(*test_inp), Foo()(*test_inp_v2))
@config.patch(dynamic_shapes=True, assume_static_by_default=False)
def test_functionalize_cond(self):
def foo(x):
def true_true_fn(x):
return x.sum() + 6
def true_false_fn(x):
return x.sum() + 9
def true_fn(x):
return cond(x.shape[0] > 6, true_true_fn, true_false_fn, [x])
def false_fn(x):
return x.sum() - 1
return cond(x.shape[0] > 5, true_fn, false_fn, [x])
example_inputs = (torch.ones(5, 2, 3),)
gm, _ = torch._dynamo.export(
foo,
*example_inputs,
aten_graph=True,
tracing_mode="symbolic",
functionalize=True,
)
self.assertEqual(gm(torch.ones(7, 2, 3)), foo(torch.ones(7, 2, 3)))
@config.patch(dynamic_shapes=True)
def test_functionalize_simple(self):
def foo(x):
def true_fn(x):
return x.sum() + 1
def false_fn(x):
return x.sum() - 1
return cond(x.shape[0] > 5, true_fn, false_fn, [x])
example_inputs = (torch.ones(5, 2, 3),)
gm, _ = torch._dynamo.export(
foo,
*example_inputs,
aten_graph=True,
tracing_mode="symbolic",
functionalize=True,
)
self.assertEqual(gm.true_graph_0(torch.ones(6, 4)), torch.ones(6, 4).sum() + 1)
self.assertEqual(gm.false_graph_0(torch.ones(6, 4)), torch.ones(6, 4).sum() - 1)
@config.patch(dynamic_shapes=True)
def test_round_dynamic_shapes(self):
def f(x):
return x[: round(x.shape[0] / 2)]

View file

@ -1,7 +1,10 @@
# Owner(s): ["module: functorch"]
import functools
import unittest
import torch
import torch.utils._pytree as pytree
from torch._functorch.aot_autograd import from_fun, to_fun
from functorch.experimental import control_flow
from functorch.experimental.control_flow import cond
from functorch.experimental.control_flow import UnsupportedAliasMutationException
@ -298,6 +301,125 @@ class TestControlFlowTraced(TestCase):
with self.assertRaisesRegex(UnsupportedAliasMutationException, "One of torch.cond branch"):
make_fx(torch.func.functionalize(f))(*example_inputs)
def test_cond_functionalized_nested_input_mutation_with_aot_func(self):
def true_true_fn(x):
x.add_(4)
return x.sin().max()
def true_false_fn(x):
return x.cos().min()
def true_fn(x):
pred = x.shape[0] == 1
return cond(pred, true_true_fn, true_false_fn, [x])
def false_fn(x):
return x.sum()
def f(x):
pred = x.shape[0] == 1
return cond(pred, true_fn, false_fn, [x])
example_input = torch.ones(4, 5)
example_input_func = to_fun(example_input)
torch._enable_functionalization(reapply_views=False)
try:
with self.assertRaisesRegex(UnsupportedAliasMutationException, "One of torch.cond branch"):
f(example_input_func)
with self.assertRaisesRegex(UnsupportedAliasMutationException, "One of torch.cond branch"):
make_fx(f)(example_input_func)
finally:
torch._disable_functionalization()
def f_wrapper(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
torch._enable_functionalization(reapply_views=False)
try:
return func(*args, **kwargs)
finally:
torch._disable_functionalization()
return wrapper
with self.assertRaisesRegex(UnsupportedAliasMutationException, "One of torch.cond branch"):
make_fx(f_wrapper(f))(example_input_func)
def test_cond_functionalized_input_aliasing_with_aot_func(self):
def true_fn(x):
return x
def false_fn(x):
view_x = x.view(x.shape)
return view_x
def f(x):
pred = x.shape[0] == 4
return cond(pred, true_fn, false_fn, [x])
example_input = torch.ones(5, 5)
example_input_func = to_fun(example_input)
torch._enable_functionalization(reapply_views=False)
try:
with self.assertRaisesRegex(UnsupportedAliasMutationException, "One of torch.cond branch might be aliasing"):
f(example_input_func)
finally:
torch._disable_functionalization()
def f_wrapper(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
torch._enable_functionalization(reapply_views=False)
try:
func_args = pytree.tree_map(to_fun, args)
func_kwargs = pytree.tree_map(to_fun, kwargs)
return func(*func_args, **func_kwargs)
finally:
torch._disable_functionalization()
return wrapper
with self.assertRaisesRegex(UnsupportedAliasMutationException, "One of torch.cond branch might be aliasing"):
make_fx(f_wrapper(f))(example_input)
def test_cond_functionalized_aot_func_check_functional(self):
def true_fn(x):
return x.cos()
def false_fn(x):
y = x.sin()
y.add_(5)
return y
def f(x):
pred = x.shape[0] == 4
return cond(pred, true_fn, false_fn, [x])
example_input = torch.ones(5, 5)
example_input_func = to_fun(example_input)
def f_wrapper(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
torch._enable_functionalization(reapply_views=False)
try:
func_args = pytree.tree_map(to_fun, args)
func_kwargs = pytree.tree_map(to_fun, kwargs)
return pytree.tree_map(from_fun, func(*args, **kwargs))
finally:
torch._disable_functionalization()
return wrapper
result_gm = make_fx(f_wrapper(f))(example_input)
for node in result_gm.true_graph_0.graph.nodes:
if node.op == "call_function":
self.assertTrue(not node.target._schema.is_mutable)
for node in result_gm.false_graph_0.graph.nodes:
if node.op == "call_function":
self.assertTrue(not node.target._schema.is_mutable)
self.assertEqual(result_gm(torch.ones(5, 5)), f(torch.ones(5, 5)))
def test_cond_nested_traced_other_inputs(self):
def true_nested(y):
return y * y
@ -641,6 +763,35 @@ class TestControlFlowTraced(TestCase):
if node.op == "call_function":
self.assertTrue(not node.target._schema.is_mutable)
def test_map_functionalized_aot_func(self):
def map_fn(x, y):
z = x + y
z.add_(4)
return z
def f(xs, y):
return control_flow.map(map_fn, xs, y)
def f_wrapper(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
torch._enable_functionalization(reapply_views=False)
try:
return pytree.tree_map(from_fun, func(*args, **kwargs))
finally:
torch._disable_functionalization()
return wrapper
example_inputs = (torch.ones(3, 2, 4), torch.ones(4))
gm = make_fx(f_wrapper(f))(*example_inputs)
for node in gm.body_graph_0.graph.nodes:
if node.op == "call_function":
self.assertTrue(not node.target._schema.is_mutable)
self.assertEqual(gm(*example_inputs), f(*example_inputs))
def test_map_functionalized_arg_mutation(self):
def map_fn(x, y):
y.add_(4)

View file

@ -703,6 +703,7 @@ def export(
tracing_mode: str = "symbolic",
constraints: List[Constraint] = None,
assume_static_by_default: bool = False,
functionalize: bool = False,
**kwargs,
) -> Tuple[torch.fx.GraphModule, Set[_guards.Guard]]:
"""
@ -727,6 +728,9 @@ def export(
tracing_mode (str): If "symbolic", turn on dynamic shapes support. Default is "symbolic".
functionalize (bool): If True, the resulting aten graph module will be functional. You will need to
set aten_graph=True to see the effect. By default, this flag will be false.
**kwargs: Arbitrary keyword arguments to be passed to the function f.
Returns:
@ -752,6 +756,12 @@ def export(
assert aten_graph, "pre_autograd=True can only be used when aten_graph=True"
f = innermost_fn(f)
if functionalize and not aten_graph:
raise UserError(
UserErrorType.ANTI_PATTERN,
"TorchDynamo won't functionalize non-aten graphs. Please set `functionalize` to true",
)
graph = None
out_guards = None
graph_captured_input = None
@ -906,10 +916,40 @@ def export(
example_fake_inputs = [fake_mode.from_tensor(t) for t in example_inputs]
if aten_graph:
memo: Dict[torch.Tensor, torch.Tensor] = {}
def to_fun(t):
if isinstance(t, torch.Tensor):
if t in memo:
return memo[t]
r = torch._to_functional_tensor(t, mirror_autograd_meta=True)
memo[t] = r
return r
else:
return t
def from_fun(t):
if not isinstance(t, torch.Tensor) or not torch._is_functional_tensor(t):
return t
torch._sync(t)
return torch._from_functional_tensor(t)
# Running graph with interpreter is needed for propagating the stack_trace
def graph_with_interpreter(*args):
with torch.fx.traceback.preserve_node_meta():
return torch.fx.Interpreter(graph).run(*args)
if functionalize:
torch._enable_functionalization(reapply_views=True)
try:
return pytree.tree_map(
from_fun,
torch.fx.Interpreter(graph).run(
*pytree.tree_map(to_fun, args)
),
)
finally:
torch._disable_functionalization()
else:
return torch.fx.Interpreter(graph).run(*args)
with enable_python_dispatcher(), fake_mode:
try:
@ -1033,9 +1073,7 @@ def export(
)
new_graph.recompile()
# TODO remove this once Executorch uses proper functionalization
new_graph._example_fake_inputs = example_fake_inputs
new_graph._matched_input_elements_positions = matched_input_elements_positions
return (new_graph, out_guards)