mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[hierarchical-compilation][hop] Introduce invoke_subgraph (#137538)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137538 Approved by: https://github.com/zou3519
This commit is contained in:
parent
046f02d2de
commit
4dd4d38ca9
8 changed files with 760 additions and 5 deletions
334
test/higher_order_ops/test_invoke_subgraph.py
Normal file
334
test/higher_order_ops/test_invoke_subgraph.py
Normal file
|
|
@ -0,0 +1,334 @@
|
|||
# Owner(s): ["module: higher order operators"]
|
||||
# flake8: noqa: B950
|
||||
|
||||
import torch
|
||||
import torch._dynamo
|
||||
import torch._functorch
|
||||
import torch._inductor
|
||||
import torch._inductor.decomposition
|
||||
from functorch.compile import aot_function, nop
|
||||
from torch._dynamo.testing import AotEagerAndRecordGraphs, normalize_gm
|
||||
from torch._higher_order_ops import invoke_subgraph
|
||||
from torch.testing._internal.common_utils import (
|
||||
run_tests,
|
||||
skipIfTorchDynamo,
|
||||
TEST_WITH_CROSSREF,
|
||||
TestCase,
|
||||
)
|
||||
|
||||
|
||||
@skipIfTorchDynamo("Not a torch._dynamo test")
class TestInvokeSubgraph(TestCase):
    """Eager-mode and aot_function coverage for the invoke_subgraph HOP.

    Each test runs a reference eager computation and an invoke_subgraph-based
    computation on independent leaf tensors, then compares forward outputs
    and gradients.
    """

    def test_simple(self):
        def gn(x, y):
            return (torch.mul(x, y),)

        def fn(x, y):
            return invoke_subgraph(gn, None, (x, y))[0]

        x = torch.randn(8, requires_grad=True)
        y = torch.randn(8, requires_grad=True)
        ref = gn(x, y)[0]

        # detach().clone() creates an independent leaf without first recording
        # the clone on the autograd tape (preferred over clone().detach()).
        x_clone = x.detach().clone().requires_grad_(True)
        y_clone = y.detach().clone().requires_grad_(True)
        res = fn(x_clone, y_clone)

        # Run backward
        ref.sum().backward()
        res.sum().backward()

        self.assertEqual(ref, res)
        self.assertEqual(x.grad, x_clone.grad)
        self.assertEqual(y.grad, y_clone.grad)

    def test_aot_function(self):
        def gn(x, y):
            return (torch.mul(x, y),)

        def fn(x, y):
            return invoke_subgraph(gn, None, (x, y))[0]

        x = torch.randn(8, requires_grad=True)
        y = torch.randn(8, requires_grad=True)
        ref = gn(x, y)[0]

        x_clone = x.detach().clone().requires_grad_(True)
        y_clone = y.detach().clone().requires_grad_(True)
        aot_fn = aot_function(fn, nop)
        res = aot_fn(x_clone, y_clone)

        # Run backward
        ref.sum().backward()
        res.sum().backward()

        self.assertEqual(ref, res)
        self.assertEqual(x.grad, x_clone.grad)
        self.assertEqual(y.grad, y_clone.grad)

    def test_multiple(self):
        # Chain three invoke_subgraph calls (two share the `cos` subgraph) and
        # check the aot_function result matches eager.
        def cos(x):
            return (torch.cos(x),)

        def sin(x):
            return (torch.sin(x),)

        def fn(x):
            a = invoke_subgraph(cos, None, (x,))[0]
            b = invoke_subgraph(sin, None, (a,))[0]
            return invoke_subgraph(cos, None, (b,))[0]

        x = torch.randn(8, requires_grad=True)
        ref = fn(x)
        aot_fn = aot_function(fn, nop)
        res = aot_fn(x)

        self.assertEqual(ref, res)

    def test_differing_strides_for_grad_outs(self):
        class CustomOp(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                return torch.sin(x)

            @staticmethod
            def backward(ctx, grad_out):
                # view() requires a contiguous grad_out; this is what exercises
                # the contiguous() call on tangents in invoke_subgraph.
                a = grad_out.view(12, 5)
                return torch.cos(torch.reshape(a, (3, 4, 5)))

        def gn(x):
            return (CustomOp.apply(x),)

        def fn(x):
            a = invoke_subgraph(gn, None, (x,))[0]
            # Force stride changes so that backward view causes a failure if
            # contiguous not called.
            b = torch.permute(a, (0, 2, 1))
            return b

        x = torch.randn(3, 4, 5, requires_grad=True)
        ref = torch.permute(gn(x)[0], (0, 2, 1))

        x_clone = x.detach().clone().requires_grad_(True)
        aot_fn = aot_function(fn, nop)
        res = aot_fn(x_clone)

        # Run backward
        ref.sum().backward()
        res.sum().backward()

        self.assertEqual(ref, res)
        self.assertEqual(x.grad, x_clone.grad)
|
||||
|
||||
|
||||
@skipIfTorchDynamo("Not a torch._dynamo test")
class TestInvokeSubgraphCompile(TestCase):
    """torch.compile coverage for invoke_subgraph.

    Verifies forward/backward correctness under compile, deduplication of
    structurally identical subgraphs in the captured Dynamo graph, and the
    exact shape of the captured graphs via assertExpectedInline.
    """

    def test_simple(self):
        def gn(x, y):
            return (torch.mul(x, y),)

        def fn(x, y):
            return invoke_subgraph(gn, None, (x, y))[0]

        x = torch.randn(8, requires_grad=True)
        y = torch.randn(8, requires_grad=True)
        ref = gn(x, y)[0]

        x_clone = x.clone().detach().requires_grad_(True)
        y_clone = y.clone().detach().requires_grad_(True)
        res = torch.compile(fn, backend="eager", fullgraph=True)(x_clone, y_clone)

        # Run backward
        ref.sum().backward()
        res.sum().backward()

        self.assertEqual(ref, res)
        self.assertEqual(x.grad, x_clone.grad)
        self.assertEqual(y.grad, y_clone.grad)

    def test_multiple(self):
        # Two invoke_subgraph calls over the same `gn`: the Dynamo graph must
        # install the subgraph module only once and reuse it.
        def gn(x, y):
            return (torch.mul(x, y),)

        def fn(x, y):
            a = invoke_subgraph(gn, None, (x, y))[0]
            return invoke_subgraph(gn, None, (a, y))[0]

        x = torch.randn(8, requires_grad=True)
        y = torch.randn(8, requires_grad=True)
        ref = fn(x, y)

        x_clone = x.clone().detach().requires_grad_(True)
        y_clone = y.clone().detach().requires_grad_(True)
        backend = AotEagerAndRecordGraphs()
        res = torch.compile(fn, backend=backend, fullgraph=True)(x_clone, y_clone)

        # Run backward
        ref.sum().backward()
        res.sum().backward()

        self.assertEqual(ref, res)
        self.assertEqual(x.grad, x_clone.grad)
        self.assertEqual(y.grad, y_clone.grad)

        # Check that the Dynamo graph has just one subgraph module
        self.assertEqual(len(backend.graphs), 1)
        subgraph_attr_names = set()
        for node in backend.graphs[0].graph.nodes:
            if node.op == "get_attr":
                subgraph_attr_names.add(node.target)
        self.assertEqual(len(subgraph_attr_names), 1)

        if not TEST_WITH_CROSSREF:
            self.assertExpectedInline(
                normalize_gm(backend.graphs[0].print_readable(print_output=False)),
                """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[8]", L_y_: "f32[8]"):
        l_x_ = L_x_
        l_y_ = L_y_

        invoke_subgraph_0 = self.invoke_subgraph_0
        invoke_subgraph = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_0, 'invoke_subgraph_0', (l_x_, l_y_)); invoke_subgraph_0 = l_x_ = None
        a: "f32[8]" = invoke_subgraph[0]; invoke_subgraph = None

        invoke_subgraph_1 = self.invoke_subgraph_0
        invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_1, 'invoke_subgraph_0', (a, l_y_)); invoke_subgraph_1 = a = l_y_ = None
        getitem_1: "f32[8]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None
        return (getitem_1,)

    class invoke_subgraph_0(torch.nn.Module):
        def forward(self, l_x_: "f32[8]", l_y_: "f32[8]"):
            child: "f32[8]" = torch.mul(l_x_, l_y_); l_x_ = l_y_ = None
            return (child,)
""",
            )

    def test_nonlocal_update(self):
        counter = 2

        def gn(x, y):
            nonlocal counter
            return (torch.mul(x, y) * counter,)

        def fn(x, y):
            nonlocal counter
            counter = 2
            a = invoke_subgraph(gn, None, (x, y))[0]
            counter = 3
            return invoke_subgraph(gn, None, (a, y))[0]

        x = torch.randn(8, requires_grad=True)
        y = torch.randn(8, requires_grad=True)
        ref = fn(x, y)

        x_clone = x.clone().detach().requires_grad_(True)
        y_clone = y.clone().detach().requires_grad_(True)
        res = torch.compile(fn, backend="eager", fullgraph=True)(x_clone, y_clone)

        # Run backward
        ref.sum().backward()
        res.sum().backward()

        self.assertEqual(ref, res)
        self.assertEqual(x.grad, x_clone.grad)
        self.assertEqual(y.grad, y_clone.grad)

        torch._dynamo.reset()
        backend = AotEagerAndRecordGraphs()
        torch.compile(fn, backend=backend, fullgraph=True)(x_clone, y_clone)

        if not TEST_WITH_CROSSREF:
            # The closed-over `counter` differs between the two calls, so the
            # two subgraphs are NOT deduped here (invoke_subgraph_0 uses * 2,
            # invoke_subgraph_1 uses * 3).
            self.assertExpectedInline(
                normalize_gm(backend.graphs[0].print_readable(print_output=False)),
                """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[8]", L_y_: "f32[8]"):
        l_x_ = L_x_
        l_y_ = L_y_

        invoke_subgraph_0 = self.invoke_subgraph_0
        invoke_subgraph = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_0, 'invoke_subgraph_0', (l_x_, l_y_)); invoke_subgraph_0 = l_x_ = None
        a: "f32[8]" = invoke_subgraph[0]; invoke_subgraph = None

        invoke_subgraph_1 = self.invoke_subgraph_1
        invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_1, 'invoke_subgraph_1', (a, l_y_)); invoke_subgraph_1 = a = l_y_ = None
        getitem_1: "f32[8]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None
        return (getitem_1,)

    class invoke_subgraph_0(torch.nn.Module):
        def forward(self, l_x_: "f32[8]", l_y_: "f32[8]"):
            mul: "f32[8]" = torch.mul(l_x_, l_y_); l_x_ = l_y_ = None
            child: "f32[8]" = mul * 2; mul = None
            return (child,)

    class invoke_subgraph_1(torch.nn.Module):
        def forward(self, a: "f32[8]", l_y_: "f32[8]"):
            mul: "f32[8]" = torch.mul(a, l_y_); a = l_y_ = None
            child: "f32[8]" = mul * 3; mul = None
            return (child,)
""",
            )

    def test_normalize_gm(self):
        def gn(x, y):
            # Different graph give different names to intermediate nodes
            for _ in range(5):
                x = x * y
            return x

        def fn(x, y):
            for _ in range(5):
                x = invoke_subgraph(gn, None, (x, y))
            return x

        backend = AotEagerAndRecordGraphs()
        opt_fn = torch.compile(fn, backend=backend, fullgraph=True)

        x = torch.randn(8, requires_grad=True)
        y = torch.randn(8, requires_grad=True)

        opt_fn(x, y)

        if not TEST_WITH_CROSSREF:
            # All five calls dedupe to a single invoke_subgraph_0 module even
            # though each trace of `gn` would name intermediates differently.
            self.assertExpectedInline(
                normalize_gm(backend.graphs[0].print_readable(print_output=False)),
                """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[8]", L_y_: "f32[8]"):
        l_x_ = L_x_
        l_y_ = L_y_

        invoke_subgraph_0 = self.invoke_subgraph_0
        invoke_subgraph = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_0, 'invoke_subgraph_0', (l_x_, l_y_)); invoke_subgraph_0 = l_x_ = None
        x: "f32[8]" = invoke_subgraph[0]; invoke_subgraph = None
        invoke_subgraph_1 = self.invoke_subgraph_0
        invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_1, 'invoke_subgraph_0', (x, l_y_)); invoke_subgraph_1 = x = None
        x_1: "f32[8]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None
        invoke_subgraph_3 = self.invoke_subgraph_0
        invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_3, 'invoke_subgraph_0', (x_1, l_y_)); invoke_subgraph_3 = x_1 = None
        x_2: "f32[8]" = invoke_subgraph_4[0]; invoke_subgraph_4 = None
        invoke_subgraph_5 = self.invoke_subgraph_0
        invoke_subgraph_6 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_5, 'invoke_subgraph_0', (x_2, l_y_)); invoke_subgraph_5 = x_2 = None
        x_3: "f32[8]" = invoke_subgraph_6[0]; invoke_subgraph_6 = None
        invoke_subgraph_7 = self.invoke_subgraph_0
        invoke_subgraph_8 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_7, 'invoke_subgraph_0', (x_3, l_y_)); invoke_subgraph_7 = x_3 = l_y_ = None
        x_4: "f32[8]" = invoke_subgraph_8[0]; invoke_subgraph_8 = None
        return (x_4,)

    class invoke_subgraph_0(torch.nn.Module):
        def forward(self, l_x_: "f32[8]", l_y_: "f32[8]"):
            x: "f32[8]" = l_x_ * l_y_; l_x_ = None
            x_1: "f32[8]" = x * l_y_; x = None
            x_2: "f32[8]" = x_1 * l_y_; x_1 = None
            x_3: "f32[8]" = x_2 * l_y_; x_2 = None
            x_4: "f32[8]" = x_3 * l_y_; x_3 = l_y_ = None
            return (x_4,)
""",
            )
|
||||
|
||||
|
||||
# Standard PyTorch test entry point.
if __name__ == "__main__":
    run_tests()
|
||||
|
|
@ -417,6 +417,7 @@ class OutputGraph:
|
|||
)
|
||||
|
||||
self.guard_on_key_order: Set[str] = set()
|
||||
self.seen_invoke_subgraphs: Dict[str, str] = {}
|
||||
|
||||
def install_builtins_dict_in_fglobals(self):
|
||||
# f_globals["__builtins__"] can be a dict or a module. This is an
|
||||
|
|
|
|||
|
|
@ -3250,6 +3250,7 @@ MOD_INLINELIST = [
|
|||
"torch._functorch.functional_call",
|
||||
"torch._functorch.vmap",
|
||||
"torch._higher_order_ops.associative_scan",
|
||||
"torch._higher_order_ops.invoke_subgraph",
|
||||
"torch._higher_order_ops.scan",
|
||||
"torch._higher_order_ops.strict_mode",
|
||||
"torch._higher_order_ops.while_loop",
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
# mypy: ignore-errors
|
||||
|
||||
import contextlib
|
||||
import copy
|
||||
import functools
|
||||
import inspect
|
||||
import itertools
|
||||
|
|
@ -18,6 +19,7 @@ from torch._dynamo.variables.functions import UserFunctionVariable
|
|||
from torch._dynamo.variables.tensor import SymNodeVariable
|
||||
from torch._guards import Source
|
||||
from torch._ops import HigherOrderOperator
|
||||
from torch.fx.node import map_arg
|
||||
from torch.fx.passes.shape_prop import _extract_tensor_metadata
|
||||
from torch.utils import _pytree as pytree
|
||||
|
||||
|
|
@ -630,6 +632,8 @@ class TorchHigherOrderOperatorVariable(VariableTracker):
|
|||
or value.__name__ == "auto_functionalized_v2"
|
||||
):
|
||||
return AutoFunctionalizeHigherOrderVariable(value, source, **kwargs)
|
||||
elif value.__name__ == "invoke_subgraph":
|
||||
return InvokeSubgraphHigherOrderVariable(value, source, **kwargs)
|
||||
else:
|
||||
unimplemented(f"HigherOrderOperator {value.__name__}")
|
||||
|
||||
|
|
@ -1064,11 +1068,15 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
|
|||
args=(
|
||||
VariableTracker.build(
|
||||
tx,
|
||||
leaf.size
|
||||
if leaf.size is not None
|
||||
else BuiltinVariable(getattr)
|
||||
.call_function(tx, [leaf, ConstantVariable.create("shape")], {})
|
||||
.items,
|
||||
(
|
||||
leaf.size
|
||||
if leaf.size is not None
|
||||
else BuiltinVariable(getattr)
|
||||
.call_function(
|
||||
tx, [leaf, ConstantVariable.create("shape")], {}
|
||||
)
|
||||
.items
|
||||
),
|
||||
),
|
||||
),
|
||||
kwargs={
|
||||
|
|
@ -2526,3 +2534,137 @@ def maybe_positional_arg_names(func):
|
|||
else:
|
||||
result.append(name)
|
||||
return result
|
||||
|
||||
|
||||
def canonicalize(gmod, root_gmod):
    """Return a copy of ``gmod`` whose nodes carry position-based names.

    autograd_cache_key is sensitive to the names of placeholder and
    intermediate nodes, so the graph is rebuilt with deterministic
    ``placeholder_<i>`` / ``node_<i>`` names before hashing.
    """
    canonical_graph = torch.fx.Graph()
    remap = {}

    placeholder_ids = itertools.count(0)
    node_ids = itertools.count(0)

    for old_node in gmod.graph.nodes:
        if old_node.op == "placeholder":
            new_node = canonical_graph.placeholder(
                f"placeholder_{next(placeholder_ids)}"
            )
        else:
            # node_copy would reuse old_node.name, which need not be unique
            # across graphs; create the node by hand with a fresh name.
            new_args = map_arg(old_node.args, lambda n: remap[n])
            new_kwargs = map_arg(old_node.kwargs, lambda n: remap[n])
            new_node = canonical_graph.create_node(
                old_node.op,
                old_node.target,
                new_args,
                new_kwargs,
                f"node_{next(node_ids)}",
                old_node.type,
            )
        new_node.meta = copy.copy(old_node.meta)
        remap[old_node] = new_node

    canonical_graph.lint()
    return torch.fx.GraphModule(root_gmod, canonical_graph)
|
||||
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def get_dummy_aot_autograd_config():
|
||||
from torch._functorch._aot_autograd.schemas import AOTConfig
|
||||
|
||||
return AOTConfig(
|
||||
fw_compiler=None,
|
||||
bw_compiler=None,
|
||||
inference_compiler=None,
|
||||
partition_fn=None,
|
||||
decompositions={},
|
||||
num_params_buffers=0,
|
||||
aot_id=0,
|
||||
keep_inference_input_mutations=False,
|
||||
dynamic_shapes=True,
|
||||
aot_autograd_arg_pos_to_source=None,
|
||||
is_export=False,
|
||||
no_tangents=False,
|
||||
enable_log=False,
|
||||
)
|
||||
|
||||
|
||||
def hash_graph_and_inputs(tx, gmod, fake_inputs):
    """Compute a stable cache key for ``gmod`` together with ``fake_inputs``.

    Reuses AOTAutograd's existing autograd_cache_key infrastructure instead of
    inventing a new hashing scheme.

    TODO(anijain2305) - Consider reorganizing autograd_cache_key such that the
    namespaces seem more intuitive. It seems somewhat confusing that we are
    calling an API from aot_autograd here.
    """
    from torch._functorch._aot_autograd.autograd_cache import autograd_cache_key

    # autograd_cache_key is sensitive to placeholder node names, so hash a
    # canonicalized copy of the graph rather than the original.
    key, _debug_lines = autograd_cache_key(
        canonicalize(gmod, tx.output.nn_modules),
        fake_inputs,
        get_dummy_aot_autograd_config(),
        {},
    )
    return key
|
||||
|
||||
|
||||
class InvokeSubgraphHigherOrderVariable(WrapHigherOrderVariable):
    """Dynamo variable for the ``invoke_subgraph`` HOP.

    Builds on WrapHigherOrderVariable's subgraph speculation and adds
    deduplication: subgraphs that hash identically (graph structure + fake
    input metadata) are installed in the output graph only once and their
    attr name is reused.
    """

    def install_subgraph_in_output_graph(
        self, tx, fn_vt, fn_args_vt, kwargs, body_gmod, attr_name="invoke_subgraph"
    ):
        # Check if the subgraph from speculate_subgraph (body_gmod) and the fake
        # inputs have already been seen before. If yes, the subgraph is already
        # installed in the output graph and we can just access the subgraph
        # using the saved attr name.

        fake_inputs = [arg.as_proxy().node.meta["example_value"] for arg in fn_args_vt]

        key = hash_graph_and_inputs(tx, body_gmod, fake_inputs)

        if key in tx.output.seen_invoke_subgraphs:
            return tx.output.seen_invoke_subgraphs[key]

        # First sighting: install via the base class, then remember the name.
        body_name = super().install_subgraph_in_output_graph(
            tx, fn_vt, fn_args_vt, kwargs, body_gmod, attr_name
        )
        tx.output.seen_invoke_subgraphs[key] = body_name
        return body_name

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        # This flattens the kwargs into lifted args
        (
            p_args,
            p_kwargs,
            example_value,
            body_r,
            treespec,
            _,
            body_name,
        ) = self.create_wrapped_node(
            tx, args[0], args[2].items, kwargs, "invoke_subgraph"
        )

        if len(p_kwargs) > 0:
            unimplemented("kwargs should have been flattened into lifted args")

        flat_example_value = pytree.tree_map_only(
            torch.fx.Proxy,
            lambda a: a.node.meta["example_value"],
            body_r.as_proxy(),
        )

        # Rebuild the HOP args as (subgraph, identifier, operands): the
        # installed (possibly deduped) attr name becomes the identifier.
        p_args = (
            p_args[0],
            body_name,
            p_args[1:],
        )
        return _call_function_and_unflatten_output(
            tx, self.value, tuple(p_args), p_kwargs, flat_example_value, treespec
        )
|
||||
|
|
|
|||
|
|
@ -4,12 +4,14 @@ from torch._higher_order_ops.flex_attention import (
|
|||
flex_attention_backward,
|
||||
)
|
||||
from torch._higher_order_ops.hints_wrap import hints_wrapper
|
||||
from torch._higher_order_ops.invoke_subgraph import invoke_subgraph
|
||||
from torch._higher_order_ops.while_loop import while_loop
|
||||
|
||||
|
||||
__all__ = [
|
||||
"cond",
|
||||
"while_loop",
|
||||
"invoke_subgraph",
|
||||
"flex_attention",
|
||||
"flex_attention_backward",
|
||||
"hints_wrapper",
|
||||
|
|
|
|||
235
torch/_higher_order_ops/invoke_subgraph.py
Normal file
235
torch/_higher_order_ops/invoke_subgraph.py
Normal file
|
|
@ -0,0 +1,235 @@
|
|||
# mypy: allow-untyped-decorators
|
||||
# mypy: allow-untyped-defs
|
||||
|
||||
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils._pytree as pytree
|
||||
from torch._C import DispatchKey
|
||||
from torch._dispatch.python import suspend_functionalization
|
||||
from torch._higher_order_ops.utils import (
|
||||
_from_fun,
|
||||
_maybe_reenter_make_fx,
|
||||
clone_outputs_aliasing_inputs,
|
||||
get_dummy_aot_autograd_config,
|
||||
prepare_fw_with_masks,
|
||||
reenter_make_fx,
|
||||
)
|
||||
from torch._ops import HigherOrderOperator
|
||||
from torch._subclasses import FakeTensorMode
|
||||
from torch._subclasses.functional_tensor import disable_functional_mode
|
||||
from torch.fx.experimental.proxy_tensor import (
|
||||
disable_proxy_modes_tracing,
|
||||
ProxyTorchDispatchMode,
|
||||
track_tensor_tree,
|
||||
)
|
||||
from torch.fx.graph_module import GraphModule
|
||||
|
||||
|
||||
# NOTE(review): appears unused within this module — confirm before removing.
invoke_subgraph_counter = 0
|
||||
|
||||
|
||||
class InvokeSubgraphHOP(HigherOrderOperator):
    """HigherOrderOperator that invokes a callable subgraph on operands.

    Call signature: ``invoke_subgraph(subgraph, identifier, operands)``.
    """

    def __init__(self) -> None:
        super().__init__("invoke_subgraph")

    # identifier is setup by upper part of the stack. This helps us in
    # identifying two invoke_subgraph calls have same subgraph.
    def __call__(
        self,
        subgraph: GraphModule,
        identifier: Optional[str],
        operands: Union[
            List[Union[torch.Tensor, torch.SymInt]],
            Tuple[Union[torch.Tensor, torch.SymInt]],
        ],
    ):
        # Validate arguments eagerly so misuse fails here with a clear message
        # rather than deep inside the dispatcher.
        assert identifier is None or isinstance(
            identifier, str
        ), "identifier must be a None or a string"

        assert isinstance(
            operands, (list, tuple)
        ), f"invoke_subgraph operands must be a list or tuple of tensors and SymInts {operands}"
        assert all(
            isinstance(o, (torch.Tensor, torch.SymInt)) for o in operands
        ), f"invoke_subgraph operands must be a list of tensors and SymInts {operands}"

        return super().__call__(subgraph, identifier, operands)
|
||||
|
||||
|
||||
# Singleton HOP instance; all dispatch-key implementations below register on it.
invoke_subgraph = InvokeSubgraphHOP()
|
||||
|
||||
|
||||
def trace_joint_graph(fn, fw_inputs, fw_outputs):
    """
    Naively trace out a joint graph. This simplifies the reconstruction of joint
    graph in the min-cut partitioner later on.
    """
    from torch._functorch.aot_autograd import create_joint

    dummy_aot_config = get_dummy_aot_autograd_config()

    # The joint function receives primals followed by tangents in one flat
    # argument list; split them back apart using the known number of inputs.
    def joint_fn(*primals_and_tangents):
        primals = primals_and_tangents[: len(fw_inputs)]
        tangents = primals_and_tangents[len(fw_inputs) :]

        fw_outs, grads = create_joint(
            prepare_fw_with_masks(fn), aot_config=dummy_aot_config
        )(primals, tangents)

        # Guard against outputs aliasing inputs (see
        # clone_outputs_aliasing_inputs for the exact policy).
        maybe_clone = clone_outputs_aliasing_inputs(primals_and_tangents)

        # Joint graph returns forward outputs followed by input grads.
        return pytree.tree_map(maybe_clone, list(fw_outs) + grads)

    primals = list(fw_inputs)
    # This assumes that the tangent strides match fw_outputs strides. Check the
    # InvokeSubgraphAutogradOp backward op for the contiguous call.
    tangents = [_from_fun(out) for out in fw_outputs]

    joint_operands = primals + tangents

    return _maybe_reenter_make_fx(joint_fn)(*joint_operands)
|
||||
|
||||
|
||||
def create_fw_bw_graph(subgraph, operands):
    """Trace a forward graph and a joint (fw + bw) graph for ``subgraph``.

    Returns ``(fw_graph, bw_graph, num_fw_outs)``. Tracing happens with
    functionalization and proxy modes suspended so that plain example tensors
    can be materialized from the functional operands.
    """
    with suspend_functionalization(), disable_functional_mode():
        with disable_proxy_modes_tracing():
            # args are functional tensors, generate some example tensors
            fw_inputs = pytree.tree_map(_from_fun, operands)

            fw_outputs = pytree.tree_map(_from_fun, subgraph(*fw_inputs))
            # Only tensor (or None) outputs are supported; anything else is a
            # user error surfaced eagerly.
            if any(
                not isinstance(out, torch.Tensor)
                for out in fw_outputs
                if out is not None
            ):
                raise RuntimeError(
                    "Expect outputs of invoke_subgraph to only contains tensors or None. "
                    f"Got types {[type(out) for out in fw_outputs]}."
                )

            # Trace the forward subgraph
            fw_graph = _maybe_reenter_make_fx(subgraph)(*fw_inputs)

            # Trace the joint graph and assign it to the bwd graph
            bw_graph = trace_joint_graph(
                subgraph,
                fw_inputs,
                fw_outputs,
            )
            return fw_graph, bw_graph, len(fw_outputs)
|
||||
|
||||
|
||||
class InvokeSubgraphAutogradOp(torch.autograd.Function):
    """
    This autograd function op is to stash the backward graph in the ctx while
    running forward.
    """

    @staticmethod
    def forward(ctx, fw_graph, bw_graph, identifier, num_fw_outs, *operands):
        # Graphs/identifier/count are not tensors, so they ride on ctx
        # attributes; tensor operands go through save_for_backward below.
        ctx._fw_graph = fw_graph
        ctx._bw_graph = bw_graph
        ctx._identifier = identifier
        ctx._num_fw_outs = num_fw_outs

        with torch._C._AutoDispatchBelowAutograd():
            # Prefix the identifier so forward and backward invocations never
            # share an identifier.
            out = invoke_subgraph(
                fw_graph,
                f"___forward_{identifier}",
                operands,
            )

        ctx.save_for_backward(*operands)
        return out

    @staticmethod
    def backward(ctx, *grad_outs):
        bw_graph = ctx._bw_graph
        identifier = ctx._identifier
        primals = ctx.saved_tensors
        num_fw_outs = ctx._num_fw_outs

        # While tracing we made the assumption that tangents are contiguous. So,
        # force the grad_outs to be contiguous.
        contiguous_grad_outs = tuple([o.contiguous() for o in grad_outs])

        # bw_graph is a joint graph with signature (*primals_and_tangents) and
        # returns (*fw_outs_and_grads). To get the grads, we use the num_fw_outs
        # to extract the grads.
        primals_and_tangents = primals + contiguous_grad_outs
        grads = invoke_subgraph(
            bw_graph, f"___backward_{identifier}", primals_and_tangents
        )[num_fw_outs:]
        # One None per non-tensor forward argument (fw_graph, bw_graph,
        # identifier, num_fw_outs), then the operand grads.
        return None, None, None, None, *grads
|
||||
|
||||
|
||||
@invoke_subgraph.py_impl(DispatchKey.CompositeExplicitAutograd)
def _(subgraph, identifier, operands):
    """Backend (CPU/CUDA) implementation: run the subgraph eagerly."""
    from torch.utils._python_dispatch import _get_current_dispatch_mode

    # No python dispatch mode may still be active once we reach the backend key.
    current_mode = _get_current_dispatch_mode()
    assert current_mode is None, "Mode should never be enabled for CPU/CUDA key"
    return subgraph(*operands)
|
||||
|
||||
|
||||
@invoke_subgraph.py_impl(DispatchKey.Autograd)
def _(subgraph, identifier, operands):
    # With grad globally disabled there is nothing to record; redispatch below
    # Autograd directly.
    if not torch.is_grad_enabled():
        with torch._C._AutoDispatchBelowAutograd():
            return invoke_subgraph(subgraph, identifier, operands)

    # A shortcut for the case where all inputs don't require gradient,
    # we skip tracing the forward and backward graph.
    if pytree.tree_all_only(
        torch.Tensor,
        lambda t: not t.requires_grad,  # type: ignore[union-attr]
        operands,
    ):
        with torch._C._AutoDispatchBelowAutograd():
            return invoke_subgraph(subgraph, identifier, operands)

    # Gradients are needed: trace fw/bw graphs and go through the autograd
    # Function so the bw graph is stashed for the backward pass.
    fw_graph, bw_graph, num_fw_outs = create_fw_bw_graph(subgraph, operands)
    # TODO(anijain2305) - Implement caching of autograd function op.
    return InvokeSubgraphAutogradOp.apply(
        fw_graph, bw_graph, identifier, num_fw_outs, *operands
    )
|
||||
|
||||
|
||||
@invoke_subgraph.py_functionalize_impl
def _(ctx, subgraph, identifier, operands):
    """Functionalization rule for invoke_subgraph.

    Unwraps the functional tensor operands, redispatches invoke_subgraph with
    a functionalized subgraph at the next dispatch level, and wraps the
    outputs back into functional tensors.
    """
    unwrapped_operands = ctx.unwrap_tensors(operands)
    # Fix: the context-manager value was previously bound to an unused name
    # (`as m`); drop the dead binding.
    with ctx.redispatch_to_next():
        # NB: There is an assumption that subgraph does not mutate inputs and
        # there is no aliasing. Its Dynamo responsibility to prevent formation
        # of invoke_subgraph ops if input aliasing/mutation is detected.
        functionalized_subgraph = ctx.functionalize(subgraph)
        out = invoke_subgraph(functionalized_subgraph, identifier, unwrapped_operands)
    return ctx.wrap_tensors(out)
|
||||
|
||||
|
||||
@invoke_subgraph.py_impl(FakeTensorMode)
def _(mode, subgraph, identifier, operands):
    # TODO(anijain2305) - Implement fake tensor caching.
    # Fake-mode rule: simply run the subgraph on the operands. ``mode`` and
    # ``identifier`` are unused here.
    return subgraph(*operands)
|
||||
|
||||
|
||||
@invoke_subgraph.py_impl(ProxyTorchDispatchMode)
def _(proxy_mode: ProxyTorchDispatchMode, subgraph, identifier, operands):
    # TODO(anijain2305) - Implement proxy tensor caching.
    # Run the HOP once to obtain example outputs for metadata tracking.
    example_out = invoke_subgraph(subgraph, identifier, operands)
    # Re-trace the subgraph to fx and install it on the tracer root so the
    # proxied call_function node below can reference it as an attribute.
    graph = reenter_make_fx(subgraph)(*operands)
    assert isinstance(proxy_mode.tracer, torch.fx.Tracer)
    qualname = proxy_mode.tracer.get_fresh_qualname("repeated_subgraph")
    proxy_mode.tracer.root.register_module(qualname, graph)

    node_args = (graph, identifier, operands)
    proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
    out_proxy = proxy_mode.tracer.create_proxy(
        "call_function", invoke_subgraph, proxy_args, {}
    )
    # Associate the example outputs with the new proxy node.
    return track_tensor_tree(
        example_out, out_proxy, constant=None, tracer=proxy_mode.tracer
    )
|
||||
|
|
@ -444,3 +444,17 @@ def saved_tensors_and_symints(ctx):
|
|||
s_idx += 1
|
||||
assert t_idx + s_idx == len(ctx.pos)
|
||||
return tuple(args)
|
||||
|
||||
|
||||
def get_dummy_aot_autograd_config():
    """Return a minimal placeholder AOTConfig.

    The compiler and partitioner slots hold None because callers only need a
    config object to pass through APIs, never to drive a real compilation.
    """
    from torch._functorch.aot_autograd import AOTConfig

    return AOTConfig(
        fw_compiler=None,  # type: ignore[arg-type]
        bw_compiler=None,  # type: ignore[arg-type]
        partition_fn=None,  # type: ignore[arg-type]
        decompositions={},
        aot_id=0,
        num_params_buffers=0,
        keep_inference_input_mutations=False,
    )
|
||||
|
|
|
|||
|
|
@ -102,6 +102,17 @@ def simple_cond(x):
|
|||
return torch.cond(x.sum() > 2, lambda x: (x.cos(),), lambda x: (x.sin(),), [x])
|
||||
|
||||
|
||||
def sample_inputs_invoke_subgraph(opinfo, device, dtype, requires_grad, **kwargs):
    """Yield a single small, strictly-positive sample for the invoke_subgraph OpInfo."""
    arg = make_tensor(
        2,
        2,
        2,
        device=device,
        dtype=dtype,
        requires_grad=requires_grad,
        low=0.1,
        high=2,
    )
    yield SampleInput(arg)
|
||||
|
||||
def simple_invoke_subgraph(x):
|
||||
def fn(x):
|
||||
return (torch.sin(x),)
|
||||
return torch._higher_order_ops.invoke_subgraph(fn, None, (x,))
|
||||
|
||||
def sample_inputs_auto_functionalize(opinfo, device, dtype, requires_grad, **kwargs):
|
||||
make_arg = functools.partial(
|
||||
make_tensor, device=device, dtype=dtype, requires_grad=False
|
||||
|
|
@ -151,6 +162,21 @@ def simple_while_loop(iter_t, x):
|
|||
|
||||
|
||||
hop_db = [
|
||||
OpInfo(
|
||||
name="invoke_subgraph",
|
||||
variant_test_name="simple",
|
||||
op=simple_invoke_subgraph,
|
||||
sample_inputs_func=sample_inputs_invoke_subgraph,
|
||||
dtypes=all_types_and(torch.bool, torch.half),
|
||||
supports_out=False,
|
||||
check_batched_grad=False,
|
||||
check_batched_gradgrad=False,
|
||||
check_batched_forward_grad=False,
|
||||
check_inplace_batched_forward_grad=False,
|
||||
supports_autograd=True,
|
||||
# "torch.compile with aot_autograd does not currently support double backward."
|
||||
supports_gradgrad=False,
|
||||
),
|
||||
OpInfo(
|
||||
name="map",
|
||||
variant_test_name="simple",
|
||||
|
|
|
|||
Loading…
Reference in a new issue