Add functionalization pass in TorchDynamo (#99461)
Fixes: https://github.com/pytorch/pytorch/issues/99000
Differential Revision: [D45106409](https://our.internmc.facebook.com/intern/diff/D45106409)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99461
Approved by: https://github.com/bdhirsh, https://github.com/anijain2305, https://github.com/zou3519
This commit is contained in:
parent 31fdd19b5b
commit bf08b072a7
5 changed files with 389 additions and 29 deletions
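At a glance, the change lets `torch._dynamo.export(..., aten_graph=True, functionalize=True)` return a functionalized aten graph, i.e. one with in-place ops rewritten into functional ops. A minimal usage sketch, adapted from the tests added in this commit (the module, shapes, and final assertion are illustrative, not part of the diff):

import torch
import torch._dynamo


class Foo(torch.nn.Module):
    def forward(self, x):
        x.add_(2)  # in-place mutation in the eager module
        return x.sum()


gm, guards = torch._dynamo.export(
    Foo(),
    torch.ones(1, 2, 3),
    aten_graph=True,
    tracing_mode="symbolic",
    functionalize=True,
)

# With functionalize=True, no mutable aten op such as aten.add_.Tensor
# should remain in the exported graph.
assert not any(
    node.op == "call_function" and node.target == torch.ops.aten.add_.Tensor
    for node in gm.graph.nodes
)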
@@ -36,7 +36,6 @@ In order to do this, we need implementations for each of the dispatch keys.
"""
cond = HigherOrderOperator("cond")


def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
    assert isinstance(operands, (list, tuple)), "Cond operands must be a list or tuple of tensors"
    assert all(isinstance(o, torch.Tensor) for o in operands), "Cond operands must be a list of tensors"
@@ -159,7 +158,6 @@ def cond_fake_tensor_mode(pred, true_fn, false_fn, operands):
    return true_outs


def _has_potential_branch_input_mutation(branch, inputs):
    """
    Dispatch-trace the branch with inputs and check if
@@ -175,18 +173,27 @@ def _has_potential_branch_input_mutation(branch, inputs):
     except Exception as e:
         raise e
 
-    input_nodes = set()
-    for node in gm.graph.nodes:
-        if node.op == "placeholder":
-            input_nodes.add(node)
-        if node.op == "call_function":
-            target = node.target
-            if isinstance(target, torch._ops.OpOverload) and target._schema.is_mutable:
-                for arg in node.args:
-                    if arg in input_nodes:
-                        return True
-
-    return False
+    def _detect_input_mutation(gm):
+        input_nodes = set()
+        for node in gm.graph.nodes:
+            if node.op == "placeholder":
+                input_nodes.add(node)
+            if node.op == "call_function":
+                target = node.target
+                if isinstance(target, torch._ops.OpOverload) and target._schema.is_mutable:
+                    for arg in node.args:
+                        if arg in input_nodes:
+                            return True
+
+        for _, module in gm.named_children():
+            if isinstance(module, torch.fx.GraphModule):
+                if _detect_input_mutation(module):
+                    return True
+
+        return False
+
+    return _detect_input_mutation(gm)
 
 
 def _has_potential_branch_input_alias(branch, inputs):
     """
@@ -204,18 +211,50 @@ def _has_potential_branch_input_alias(branch, inputs):
     except Exception as e:
         raise e
 
-    input_storages = set()
-    for node in gm.graph.nodes:
-        if node.op == "placeholder":
-            input_storages.add(StorageWeakRef(node.meta['val']._typed_storage()))
-        if node.op == "output":
-            for out in node.args:
-                out_storage = StorageWeakRef(out.meta["val"]._typed_storage())
-                if out_storage in input_storages:
-                    return True
-
-    return False
+    def _detect_input_alias(gm):
+        input_storages = set()
+        for node in gm.graph.nodes:
+            if node.op == "placeholder":
+                input_storages.add(StorageWeakRef(node.meta['val']._typed_storage()))
+            if node.op == "output":
+                for out in node.args:
+                    out_storage = StorageWeakRef(out.meta["val"]._typed_storage())
+                    if out_storage in input_storages:
+                        return True
+
+        for _, module in gm.named_children():
+            if isinstance(module, torch.fx.GraphModule) and _detect_input_alias(module):
+                return True
+
+        return False
+
+    return _detect_input_alias(gm)
+
+
+@cond.py_impl(DispatchKey.Functionalize)
+def cond_func(pred, true_fn, false_fn, inputs):
+    reapply_views = torch._C._functionalization_reapply_views_tls()
+    unwrapped_inputs = _unwrap_all_tensors_from_functional(inputs, reapply_views=reapply_views)
+    unwrapped_pred = _unwrap_all_tensors_from_functional(pred, reapply_views=reapply_views)
+    mode = 'mutations_and_views' if reapply_views else 'mutations'
+    guard = ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize))
+    try:
+        functional_true = functionalize(true_fn, remove=mode)
+        functional_false = functionalize(false_fn, remove=mode)
+        for branch in [true_fn, false_fn]:
+            if _has_potential_branch_input_mutation(branch, unwrapped_inputs):
+                raise UnsupportedAliasMutationException("One of torch.cond branch "
+                                                        "might be modifying the input!")
+
+            if _has_potential_branch_input_alias(branch, unwrapped_inputs):
+                raise UnsupportedAliasMutationException("One of torch.cond branch "
+                                                        "might be aliasing the input!")
+
+        cond_return = cond(unwrapped_pred, functional_true, functional_false, unwrapped_inputs)
+        return _wrap_all_tensors_to_functional(cond_return, level=0)
+
+    finally:
+        del guard
 
 
 @cond.py_impl(torch._C._functorch.TransformType.Functionalize)
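The `_detect_input_mutation` and `_detect_input_alias` helpers above reduce to scanning a `make_fx`-traced graph for schema-mutable ops applied to placeholder nodes, or outputs whose storage aliases an input. A stand-alone sketch of the mutation check, assuming the branch is traceable with `make_fx` (the `branch` function here is hypothetical and not part of the diff):

import torch
from torch.fx.experimental.proxy_tensor import make_fx


def branch(x):
    x.add_(1)  # mutates the branch input in place
    return x.cos()


gm = make_fx(branch)(torch.ones(3))
placeholders = {n for n in gm.graph.nodes if n.op == "placeholder"}
mutates_input = any(
    n.op == "call_function"
    and isinstance(n.target, torch._ops.OpOverload)
    and n.target._schema.is_mutable
    and any(a in placeholders for a in n.args)
    for n in gm.graph.nodes
)
print(mutates_input)  # expected: True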
@@ -91,6 +91,34 @@ def map_fake_tensor_mode(f, xs, *args):
    outs = [f(x, *args) for x in xs]
    return outs[0].new_empty([xs.shape[0], *outs[0].shape])


@map.py_impl(DispatchKey.Functionalize)
def map_func(f, xs, *args):
    reapply_views = torch._C._functionalization_reapply_views_tls()
    unwrapped_xs = _unwrap_all_tensors_from_functional(xs, reapply_views=reapply_views)
    unwrapped_args = _unwrap_all_tensors_from_functional(args, reapply_views=reapply_views)
    mode = 'mutations_and_views' if reapply_views else 'mutations'

    guard = ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize))
    try:
        functional_map_fn = functionalize(f, remove=mode)
        inputs = (unwrapped_xs,) + unwrapped_args

        if _has_potential_branch_input_mutation(f, inputs):
            raise UnsupportedAliasMutationException(
                "torch.map is mutating the input!"
            )

        if _has_potential_branch_input_alias(f, inputs):
            raise UnsupportedAliasMutationException(
                "torch.map is aliasing the input!"
            )

        map_return = map(functional_map_fn, unwrapped_xs, *unwrapped_args)
        return _wrap_all_tensors_to_functional(map_return, level=0)
    finally:
        del guard


@map.py_impl(torch._C._functorch.TransformType.Functionalize)
def map_functionalize(interpreter, f, xs, *args):
    """
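As with `cond`, the functionalized `map` kernel rejects bodies that mutate or alias their operands. A rough sketch of that failure mode, mirroring the map mutation test later in this diff; the exception type comes from `functorch.experimental.control_flow`, and it is assumed here that the functorch-transform path raises the same error as the dispatch-key kernel shown above:

import torch
from functorch.experimental import control_flow
from functorch.experimental.control_flow import UnsupportedAliasMutationException
from torch.fx.experimental.proxy_tensor import make_fx


def body(x, y):
    y.add_(4)  # mutates an operand of the mapped function
    return x + y


def f(xs, y):
    return control_flow.map(body, xs, y)


try:
    make_fx(torch.func.functionalize(f))(torch.ones(3, 2, 4), torch.ones(4))
except UnsupportedAliasMutationException as e:
    print("rejected:", e)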
@@ -12,8 +12,6 @@ from enum import Enum
from typing import Dict, List, Sequence
from unittest.mock import patch

import pytest

import torch
import torch._dynamo
import torch._dynamo.test_case
@@ -2335,7 +2333,6 @@ def forward(self, x):
                preserved = True
        self.assertTrue(preserved)

    @pytest.mark.xfail(reason="Saving example_fake_inputs breaks the serialization")
    @config.patch(
        dynamic_shapes=True,
        capture_dynamic_output_shape_ops=True,
@@ -2655,6 +2652,113 @@ def forward(self, x):

        self.assertEqual(f_correct(torch.ones(6, 4)), gm(torch.ones(6, 4)))

    @config.patch(dynamic_shapes=True)
    def test_functionalize(self):
        class Foo(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("buffer1", torch.ones(6, 2))

            def forward(self, x):
                x.add_(2)
                return x.sum() + self.buffer1.sum()

        example_inputs = (torch.ones(1, 2, 3),)
        gm, _ = torch._dynamo.export(
            Foo(),
            *example_inputs,
            aten_graph=True,
            tracing_mode="symbolic",
            functionalize=True,
        )

        count = 0
        for node in gm.graph.nodes:
            if node.target == torch.ops.aten.add_.Tensor:
                count += 1
        self.assertEqual(count, 0)
        test_inp = (torch.ones(1, 2, 3),)
        test_inp_v2 = (torch.ones(1, 2, 3),)
        self.assertEqual(gm(*test_inp), Foo()(*test_inp_v2))

    @config.patch(dynamic_shapes=True)
    def test_not_functionalize(self):
        class Foo(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("buffer1", torch.ones(6, 2))

            def forward(self, x):
                x.add_(2)
                return x.sum() + self.buffer1.sum()

        example_inputs = (torch.ones(1, 2, 3),)
        gm, _ = torch._dynamo.export(
            Foo(),
            *example_inputs,
            aten_graph=True,
            tracing_mode="symbolic",
            functionalize=False,
        )
        count = 0
        for node in gm.graph.nodes:
            if node.target == torch.ops.aten.add_.Tensor:
                count += 1
        self.assertEqual(count, 1)
        test_inp = (torch.ones(1, 2, 3),)
        test_inp_v2 = (torch.ones(1, 2, 3),)
        self.assertEqual(gm(*test_inp), Foo()(*test_inp_v2))

    @config.patch(dynamic_shapes=True, assume_static_by_default=False)
    def test_functionalize_cond(self):
        def foo(x):
            def true_true_fn(x):
                return x.sum() + 6

            def true_false_fn(x):
                return x.sum() + 9

            def true_fn(x):
                return cond(x.shape[0] > 6, true_true_fn, true_false_fn, [x])

            def false_fn(x):
                return x.sum() - 1

            return cond(x.shape[0] > 5, true_fn, false_fn, [x])

        example_inputs = (torch.ones(5, 2, 3),)
        gm, _ = torch._dynamo.export(
            foo,
            *example_inputs,
            aten_graph=True,
            tracing_mode="symbolic",
            functionalize=True,
        )
        self.assertEqual(gm(torch.ones(7, 2, 3)), foo(torch.ones(7, 2, 3)))

    @config.patch(dynamic_shapes=True)
    def test_functionalize_simple(self):
        def foo(x):
            def true_fn(x):
                return x.sum() + 1

            def false_fn(x):
                return x.sum() - 1

            return cond(x.shape[0] > 5, true_fn, false_fn, [x])

        example_inputs = (torch.ones(5, 2, 3),)
        gm, _ = torch._dynamo.export(
            foo,
            *example_inputs,
            aten_graph=True,
            tracing_mode="symbolic",
            functionalize=True,
        )
        self.assertEqual(gm.true_graph_0(torch.ones(6, 4)), torch.ones(6, 4).sum() + 1)
        self.assertEqual(gm.false_graph_0(torch.ones(6, 4)), torch.ones(6, 4).sum() - 1)

    @config.patch(dynamic_shapes=True)
    def test_round_dynamic_shapes(self):
        def f(x):
            return x[: round(x.shape[0] / 2)]
@@ -1,7 +1,10 @@
# Owner(s): ["module: functorch"]
import functools
import unittest

import torch
import torch.utils._pytree as pytree
from torch._functorch.aot_autograd import from_fun, to_fun
from functorch.experimental import control_flow
from functorch.experimental.control_flow import cond
from functorch.experimental.control_flow import UnsupportedAliasMutationException
@@ -298,6 +301,125 @@ class TestControlFlowTraced(TestCase):
        with self.assertRaisesRegex(UnsupportedAliasMutationException, "One of torch.cond branch"):
            make_fx(torch.func.functionalize(f))(*example_inputs)

    def test_cond_functionalized_nested_input_mutation_with_aot_func(self):
        def true_true_fn(x):
            x.add_(4)
            return x.sin().max()

        def true_false_fn(x):
            return x.cos().min()

        def true_fn(x):
            pred = x.shape[0] == 1
            return cond(pred, true_true_fn, true_false_fn, [x])

        def false_fn(x):
            return x.sum()

        def f(x):
            pred = x.shape[0] == 1
            return cond(pred, true_fn, false_fn, [x])

        example_input = torch.ones(4, 5)
        example_input_func = to_fun(example_input)
        torch._enable_functionalization(reapply_views=False)
        try:
            with self.assertRaisesRegex(UnsupportedAliasMutationException, "One of torch.cond branch"):
                f(example_input_func)

            with self.assertRaisesRegex(UnsupportedAliasMutationException, "One of torch.cond branch"):
                make_fx(f)(example_input_func)
        finally:
            torch._disable_functionalization()

        def f_wrapper(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                torch._enable_functionalization(reapply_views=False)
                try:
                    return func(*args, **kwargs)
                finally:
                    torch._disable_functionalization()
            return wrapper

        with self.assertRaisesRegex(UnsupportedAliasMutationException, "One of torch.cond branch"):
            make_fx(f_wrapper(f))(example_input_func)

    def test_cond_functionalized_input_aliasing_with_aot_func(self):
        def true_fn(x):
            return x

        def false_fn(x):
            view_x = x.view(x.shape)
            return view_x

        def f(x):
            pred = x.shape[0] == 4
            return cond(pred, true_fn, false_fn, [x])

        example_input = torch.ones(5, 5)
        example_input_func = to_fun(example_input)
        torch._enable_functionalization(reapply_views=False)
        try:
            with self.assertRaisesRegex(UnsupportedAliasMutationException, "One of torch.cond branch might be aliasing"):
                f(example_input_func)
        finally:
            torch._disable_functionalization()

        def f_wrapper(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                torch._enable_functionalization(reapply_views=False)
                try:
                    func_args = pytree.tree_map(to_fun, args)
                    func_kwargs = pytree.tree_map(to_fun, kwargs)
                    return func(*func_args, **func_kwargs)
                finally:
                    torch._disable_functionalization()
            return wrapper

        with self.assertRaisesRegex(UnsupportedAliasMutationException, "One of torch.cond branch might be aliasing"):
            make_fx(f_wrapper(f))(example_input)

    def test_cond_functionalized_aot_func_check_functional(self):
        def true_fn(x):
            return x.cos()

        def false_fn(x):
            y = x.sin()
            y.add_(5)
            return y

        def f(x):
            pred = x.shape[0] == 4
            return cond(pred, true_fn, false_fn, [x])

        example_input = torch.ones(5, 5)
        example_input_func = to_fun(example_input)

        def f_wrapper(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                torch._enable_functionalization(reapply_views=False)
                try:
                    func_args = pytree.tree_map(to_fun, args)
                    func_kwargs = pytree.tree_map(to_fun, kwargs)
                    return pytree.tree_map(from_fun, func(*args, **kwargs))
                finally:
                    torch._disable_functionalization()
            return wrapper

        result_gm = make_fx(f_wrapper(f))(example_input)
        for node in result_gm.true_graph_0.graph.nodes:
            if node.op == "call_function":
                self.assertTrue(not node.target._schema.is_mutable)

        for node in result_gm.false_graph_0.graph.nodes:
            if node.op == "call_function":
                self.assertTrue(not node.target._schema.is_mutable)

        self.assertEqual(result_gm(torch.ones(5, 5)), f(torch.ones(5, 5)))

    def test_cond_nested_traced_other_inputs(self):
        def true_nested(y):
            return y * y
@@ -641,6 +763,35 @@ class TestControlFlowTraced(TestCase):
            if node.op == "call_function":
                self.assertTrue(not node.target._schema.is_mutable)

    def test_map_functionalized_aot_func(self):
        def map_fn(x, y):
            z = x + y
            z.add_(4)
            return z

        def f(xs, y):
            return control_flow.map(map_fn, xs, y)

        def f_wrapper(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                torch._enable_functionalization(reapply_views=False)
                try:
                    return pytree.tree_map(from_fun, func(*args, **kwargs))
                finally:
                    torch._disable_functionalization()
            return wrapper

        example_inputs = (torch.ones(3, 2, 4), torch.ones(4))

        gm = make_fx(f_wrapper(f))(*example_inputs)

        for node in gm.body_graph_0.graph.nodes:
            if node.op == "call_function":
                self.assertTrue(not node.target._schema.is_mutable)

        self.assertEqual(gm(*example_inputs), f(*example_inputs))

    def test_map_functionalized_arg_mutation(self):
        def map_fn(x, y):
            y.add_(4)
@@ -703,6 +703,7 @@ def export(
    tracing_mode: str = "symbolic",
    constraints: List[Constraint] = None,
    assume_static_by_default: bool = False,
    functionalize: bool = False,
    **kwargs,
) -> Tuple[torch.fx.GraphModule, Set[_guards.Guard]]:
    """
@@ -727,6 +728,9 @@ def export(

        tracing_mode (str): If "symbolic", turn on dynamic shapes support. Default is "symbolic".

        functionalize (bool): If True, the resulting aten graph module will be functional. You will need to
            set aten_graph=True to see the effect. By default, this flag is False.

        **kwargs: Arbitrary keyword arguments to be passed to the function f.

    Returns:
@@ -752,6 +756,12 @@ def export(
    assert aten_graph, "pre_autograd=True can only be used when aten_graph=True"
    f = innermost_fn(f)

    if functionalize and not aten_graph:
        raise UserError(
            UserErrorType.ANTI_PATTERN,
            "TorchDynamo won't functionalize non-aten graphs. Please set `aten_graph` to True when using `functionalize`.",
        )

    graph = None
    out_guards = None
    graph_captured_input = None
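A quick sketch of hitting the new validation above; it assumes `UserError` is importable from `torch._dynamo.exc`, which this diff does not show:

import torch
import torch._dynamo
from torch._dynamo.exc import UserError

try:
    # functionalize=True without aten_graph=True is rejected per the check above
    torch._dynamo.export(
        lambda x: x + 1,
        torch.ones(3),
        aten_graph=False,
        functionalize=True,
    )
except UserError as e:
    print("rejected:", e)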
@@ -906,10 +916,40 @@ def export(
     example_fake_inputs = [fake_mode.from_tensor(t) for t in example_inputs]
 
     if aten_graph:
+        memo: Dict[torch.Tensor, torch.Tensor] = {}
+
+        def to_fun(t):
+            if isinstance(t, torch.Tensor):
+                if t in memo:
+                    return memo[t]
+                r = torch._to_functional_tensor(t, mirror_autograd_meta=True)
+                memo[t] = r
+                return r
+            else:
+                return t
+
+        def from_fun(t):
+            if not isinstance(t, torch.Tensor) or not torch._is_functional_tensor(t):
+                return t
+            torch._sync(t)
+            return torch._from_functional_tensor(t)
+
         # Running graph with interpreter is needed for propagating the stack_trace
         def graph_with_interpreter(*args):
             with torch.fx.traceback.preserve_node_meta():
-                return torch.fx.Interpreter(graph).run(*args)
+                if functionalize:
+                    torch._enable_functionalization(reapply_views=True)
+                    try:
+                        return pytree.tree_map(
+                            from_fun,
+                            torch.fx.Interpreter(graph).run(
+                                *pytree.tree_map(to_fun, args)
+                            ),
+                        )
+                    finally:
+                        torch._disable_functionalization()
+                else:
+                    return torch.fx.Interpreter(graph).run(*args)
 
         with enable_python_dispatcher(), fake_mode:
             try:
@@ -1033,9 +1073,7 @@ def export(
        )

    new_graph.recompile()

    # TODO remove this once Executorch uses proper functionalization
    new_graph._example_fake_inputs = example_fake_inputs
    new_graph._matched_input_elements_positions = matched_input_elements_positions

    return (new_graph, out_guards)