[hierarchical-compilation][hop] Introduce invoke_subgraph (#137538)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137538
Approved by: https://github.com/zou3519
Authored by Animesh Jain on 2024-10-21 16:33:43 -07:00; committed by PyTorch MergeBot
parent 046f02d2de
commit 4dd4d38ca9
8 changed files with 760 additions and 5 deletions
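For orientation, here is a minimal usage sketch distilled from the tests added in this PR (illustrative only, not part of the diff): invoke_subgraph(subgraph, identifier, operands) runs subgraph on a tuple of operands; the identifier may be None in eager code, and Dynamo assigns one during tracing so that repeated calls with the same subgraph can share a single installed graph module.

import torch
from torch._higher_order_ops import invoke_subgraph


def gn(x, y):
    # The subgraph returns a tuple of tensors, as in the tests below.
    return (torch.mul(x, y),)


x = torch.randn(8, requires_grad=True)
y = torch.randn(8, requires_grad=True)

# Eager call: identifier is None; the Autograd dispatch key traces fw/bw graphs.
out = invoke_subgraph(gn, None, (x, y))[0]
out.sum().backward()

# Under torch.compile, repeated invoke_subgraph calls with the same subgraph
# and input metadata are deduplicated into one subgraph module.
compiled = torch.compile(
    lambda a, b: invoke_subgraph(gn, None, (a, b))[0],
    backend="eager",
    fullgraph=True,
)
print(torch.allclose(compiled(x, y), out))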


@@ -0,0 +1,334 @@
# Owner(s): ["module: higher order operators"]
# flake8: noqa: B950

import torch
import torch._dynamo
import torch._functorch
import torch._inductor
import torch._inductor.decomposition
from functorch.compile import aot_function, nop
from torch._dynamo.testing import AotEagerAndRecordGraphs, normalize_gm
from torch._higher_order_ops import invoke_subgraph
from torch.testing._internal.common_utils import (
    run_tests,
    skipIfTorchDynamo,
    TEST_WITH_CROSSREF,
    TestCase,
)


@skipIfTorchDynamo("Not a torch._dynamo test")
class TestInvokeSubgraph(TestCase):
    def test_simple(self):
        def gn(x, y):
            return (torch.mul(x, y),)

        def fn(x, y):
            return invoke_subgraph(gn, None, (x, y))[0]

        x = torch.randn(8, requires_grad=True)
        y = torch.randn(8, requires_grad=True)
        ref = gn(x, y)[0]

        x_clone = x.clone().detach().requires_grad_(True)
        y_clone = y.clone().detach().requires_grad_(True)
        res = fn(x_clone, y_clone)

        # Run backward
        ref.sum().backward()
        res.sum().backward()

        self.assertEqual(ref, res)
        self.assertEqual(x.grad, x_clone.grad)
        self.assertEqual(y.grad, y_clone.grad)
    def test_aot_function(self):
        def gn(x, y):
            return (torch.mul(x, y),)

        def fn(x, y):
            return invoke_subgraph(gn, None, (x, y))[0]

        x = torch.randn(8, requires_grad=True)
        y = torch.randn(8, requires_grad=True)
        ref = gn(x, y)[0]

        x_clone = x.clone().detach().requires_grad_(True)
        y_clone = y.clone().detach().requires_grad_(True)
        aot_fn = aot_function(fn, nop)
        res = aot_fn(x_clone, y_clone)

        # Run backward
        ref.sum().backward()
        res.sum().backward()

        self.assertEqual(ref, res)
        self.assertEqual(x.grad, x_clone.grad)
        self.assertEqual(y.grad, y_clone.grad)
    def test_multiple(self):
        n_layers = 2

        def cos(x):
            return (torch.cos(x),)

        def sin(x):
            return (torch.sin(x),)

        def fn(x):
            a = invoke_subgraph(cos, None, (x,))[0]
            b = invoke_subgraph(sin, None, (a,))[0]
            return invoke_subgraph(cos, None, (b,))[0]

        x = torch.randn(8, requires_grad=True)
        ref = fn(x)

        aot_fn = aot_function(fn, nop)
        res = aot_fn(x)

        self.assertEqual(ref, res)
    def test_differing_strides_for_grad_outs(self):
        class CustomOp(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                return torch.sin(x)

            @staticmethod
            def backward(ctx, grad_out):
                a = grad_out.view(12, 5)
                return torch.cos(torch.reshape(a, (3, 4, 5)))

        def gn(x):
            return (CustomOp.apply(x),)

        def fn(x):
            a = invoke_subgraph(gn, None, (x,))[0]
            # Force stride changes so that the backward view fails if
            # contiguous is not called.
            b = torch.permute(a, (0, 2, 1))
            return b

        x = torch.randn(3, 4, 5, requires_grad=True)
        ref = torch.permute(gn(x)[0], (0, 2, 1))

        x_clone = x.clone().detach().requires_grad_(True)
        aot_fn = aot_function(fn, nop)
        res = aot_fn(x_clone)

        # Run backward
        ref.sum().backward()
        res.sum().backward()

        self.assertEqual(ref, res)
        self.assertEqual(x.grad, x_clone.grad)

@skipIfTorchDynamo("Not a torch._dynamo test")
class TestInvokeSubgraphCompile(TestCase):
    def test_simple(self):
        def gn(x, y):
            return (torch.mul(x, y),)

        def fn(x, y):
            return invoke_subgraph(gn, None, (x, y))[0]

        x = torch.randn(8, requires_grad=True)
        y = torch.randn(8, requires_grad=True)
        ref = gn(x, y)[0]

        x_clone = x.clone().detach().requires_grad_(True)
        y_clone = y.clone().detach().requires_grad_(True)
        res = torch.compile(fn, backend="eager", fullgraph=True)(x_clone, y_clone)

        # Run backward
        ref.sum().backward()
        res.sum().backward()

        self.assertEqual(ref, res)
        self.assertEqual(x.grad, x_clone.grad)
        self.assertEqual(y.grad, y_clone.grad)
    def test_multiple(self):
        def gn(x, y):
            return (torch.mul(x, y),)

        def fn(x, y):
            a = invoke_subgraph(gn, None, (x, y))[0]
            return invoke_subgraph(gn, None, (a, y))[0]

        x = torch.randn(8, requires_grad=True)
        y = torch.randn(8, requires_grad=True)
        ref = fn(x, y)

        x_clone = x.clone().detach().requires_grad_(True)
        y_clone = y.clone().detach().requires_grad_(True)
        backend = AotEagerAndRecordGraphs()
        res = torch.compile(fn, backend=backend, fullgraph=True)(x_clone, y_clone)

        # Run backward
        ref.sum().backward()
        res.sum().backward()

        self.assertEqual(ref, res)
        self.assertEqual(x.grad, x_clone.grad)
        self.assertEqual(y.grad, y_clone.grad)

        # Check that the Dynamo graph has just one subgraph module
        self.assertEqual(len(backend.graphs), 1)
        subgraph_attr_names = set()
        for node in backend.graphs[0].graph.nodes:
            if node.op == "get_attr":
                subgraph_attr_names.add(node.target)
        self.assertEqual(len(subgraph_attr_names), 1)

        if not TEST_WITH_CROSSREF:
            self.assertExpectedInline(
                normalize_gm(backend.graphs[0].print_readable(print_output=False)),
                """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[8]", L_y_: "f32[8]"):
        l_x_ = L_x_
        l_y_ = L_y_
        invoke_subgraph_0 = self.invoke_subgraph_0
        invoke_subgraph = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_0, 'invoke_subgraph_0', (l_x_, l_y_)); invoke_subgraph_0 = l_x_ = None
        a: "f32[8]" = invoke_subgraph[0]; invoke_subgraph = None
        invoke_subgraph_1 = self.invoke_subgraph_0
        invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_1, 'invoke_subgraph_0', (a, l_y_)); invoke_subgraph_1 = a = l_y_ = None
        getitem_1: "f32[8]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None
        return (getitem_1,)

    class invoke_subgraph_0(torch.nn.Module):
        def forward(self, l_x_: "f32[8]", l_y_: "f32[8]"):
            child: "f32[8]" = torch.mul(l_x_, l_y_); l_x_ = l_y_ = None
            return (child,)
""",
            )
    def test_nonlocal_update(self):
        counter = 2

        def gn(x, y):
            nonlocal counter
            return (torch.mul(x, y) * counter,)

        def fn(x, y):
            nonlocal counter
            counter = 2
            a = invoke_subgraph(gn, None, (x, y))[0]
            counter = 3
            return invoke_subgraph(gn, None, (a, y))[0]

        x = torch.randn(8, requires_grad=True)
        y = torch.randn(8, requires_grad=True)
        ref = fn(x, y)

        x_clone = x.clone().detach().requires_grad_(True)
        y_clone = y.clone().detach().requires_grad_(True)
        res = torch.compile(fn, backend="eager", fullgraph=True)(x_clone, y_clone)

        # Run backward
        ref.sum().backward()
        res.sum().backward()

        self.assertEqual(ref, res)
        self.assertEqual(x.grad, x_clone.grad)
        self.assertEqual(y.grad, y_clone.grad)

        torch._dynamo.reset()
        backend = AotEagerAndRecordGraphs()
        torch.compile(fn, backend=backend, fullgraph=True)(x_clone, y_clone)

        if not TEST_WITH_CROSSREF:
            self.assertExpectedInline(
                normalize_gm(backend.graphs[0].print_readable(print_output=False)),
                """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[8]", L_y_: "f32[8]"):
        l_x_ = L_x_
        l_y_ = L_y_
        invoke_subgraph_0 = self.invoke_subgraph_0
        invoke_subgraph = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_0, 'invoke_subgraph_0', (l_x_, l_y_)); invoke_subgraph_0 = l_x_ = None
        a: "f32[8]" = invoke_subgraph[0]; invoke_subgraph = None
        invoke_subgraph_1 = self.invoke_subgraph_1
        invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_1, 'invoke_subgraph_1', (a, l_y_)); invoke_subgraph_1 = a = l_y_ = None
        getitem_1: "f32[8]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None
        return (getitem_1,)

    class invoke_subgraph_0(torch.nn.Module):
        def forward(self, l_x_: "f32[8]", l_y_: "f32[8]"):
            mul: "f32[8]" = torch.mul(l_x_, l_y_); l_x_ = l_y_ = None
            child: "f32[8]" = mul * 2; mul = None
            return (child,)

    class invoke_subgraph_1(torch.nn.Module):
        def forward(self, a: "f32[8]", l_y_: "f32[8]"):
            mul: "f32[8]" = torch.mul(a, l_y_); a = l_y_ = None
            child: "f32[8]" = mul * 3; mul = None
            return (child,)
""",
            )
    def test_normalize_gm(self):
        def gn(x, y):
            # Different graphs give different names to intermediate nodes
            for _ in range(5):
                x = x * y
            return x

        def fn(x, y):
            for _ in range(5):
                x = invoke_subgraph(gn, None, (x, y))
            return x

        backend = AotEagerAndRecordGraphs()
        opt_fn = torch.compile(fn, backend=backend, fullgraph=True)

        x = torch.randn(8, requires_grad=True)
        y = torch.randn(8, requires_grad=True)

        opt_fn(x, y)

        if not TEST_WITH_CROSSREF:
            self.assertExpectedInline(
                normalize_gm(backend.graphs[0].print_readable(print_output=False)),
                """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[8]", L_y_: "f32[8]"):
        l_x_ = L_x_
        l_y_ = L_y_
        invoke_subgraph_0 = self.invoke_subgraph_0
        invoke_subgraph = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_0, 'invoke_subgraph_0', (l_x_, l_y_)); invoke_subgraph_0 = l_x_ = None
        x: "f32[8]" = invoke_subgraph[0]; invoke_subgraph = None
        invoke_subgraph_1 = self.invoke_subgraph_0
        invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_1, 'invoke_subgraph_0', (x, l_y_)); invoke_subgraph_1 = x = None
        x_1: "f32[8]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None
        invoke_subgraph_3 = self.invoke_subgraph_0
        invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_3, 'invoke_subgraph_0', (x_1, l_y_)); invoke_subgraph_3 = x_1 = None
        x_2: "f32[8]" = invoke_subgraph_4[0]; invoke_subgraph_4 = None
        invoke_subgraph_5 = self.invoke_subgraph_0
        invoke_subgraph_6 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_5, 'invoke_subgraph_0', (x_2, l_y_)); invoke_subgraph_5 = x_2 = None
        x_3: "f32[8]" = invoke_subgraph_6[0]; invoke_subgraph_6 = None
        invoke_subgraph_7 = self.invoke_subgraph_0
        invoke_subgraph_8 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_7, 'invoke_subgraph_0', (x_3, l_y_)); invoke_subgraph_7 = x_3 = l_y_ = None
        x_4: "f32[8]" = invoke_subgraph_8[0]; invoke_subgraph_8 = None
        return (x_4,)

    class invoke_subgraph_0(torch.nn.Module):
        def forward(self, l_x_: "f32[8]", l_y_: "f32[8]"):
            x: "f32[8]" = l_x_ * l_y_; l_x_ = None
            x_1: "f32[8]" = x * l_y_; x = None
            x_2: "f32[8]" = x_1 * l_y_; x_1 = None
            x_3: "f32[8]" = x_2 * l_y_; x_2 = None
            x_4: "f32[8]" = x_3 * l_y_; x_3 = l_y_ = None
            return (x_4,)
""",
            )


if __name__ == "__main__":
    run_tests()


@@ -417,6 +417,7 @@ class OutputGraph:
        )
        self.guard_on_key_order: Set[str] = set()
        self.seen_invoke_subgraphs: Dict[str, str] = {}

    def install_builtins_dict_in_fglobals(self):
        # f_globals["__builtins__"] can be a dict or a module. This is an


@@ -3250,6 +3250,7 @@ MOD_INLINELIST = [
    "torch._functorch.functional_call",
    "torch._functorch.vmap",
    "torch._higher_order_ops.associative_scan",
    "torch._higher_order_ops.invoke_subgraph",
    "torch._higher_order_ops.scan",
    "torch._higher_order_ops.strict_mode",
    "torch._higher_order_ops.while_loop",


@@ -1,6 +1,7 @@
# mypy: ignore-errors

import contextlib
import copy
import functools
import inspect
import itertools
@@ -18,6 +19,7 @@ from torch._dynamo.variables.functions import UserFunctionVariable
from torch._dynamo.variables.tensor import SymNodeVariable
from torch._guards import Source
from torch._ops import HigherOrderOperator
from torch.fx.node import map_arg
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from torch.utils import _pytree as pytree
@@ -630,6 +632,8 @@ class TorchHigherOrderOperatorVariable(VariableTracker):
            or value.__name__ == "auto_functionalized_v2"
        ):
            return AutoFunctionalizeHigherOrderVariable(value, source, **kwargs)
        elif value.__name__ == "invoke_subgraph":
            return InvokeSubgraphHigherOrderVariable(value, source, **kwargs)
        else:
            unimplemented(f"HigherOrderOperator {value.__name__}")
@@ -1064,11 +1068,15 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
                args=(
                    VariableTracker.build(
                        tx,
-                       leaf.size
-                       if leaf.size is not None
-                       else BuiltinVariable(getattr)
-                       .call_function(tx, [leaf, ConstantVariable.create("shape")], {})
-                       .items,
+                       (
+                           leaf.size
+                           if leaf.size is not None
+                           else BuiltinVariable(getattr)
+                           .call_function(
+                               tx, [leaf, ConstantVariable.create("shape")], {}
+                           )
+                           .items
+                       ),
                    ),
                ),
                kwargs={
@@ -2526,3 +2534,137 @@ def maybe_positional_arg_names(func):
        else:
            result.append(name)
    return result


def canonicalize(gmod, root_gmod):
    # autograd_cache_key is sensitive to the names of the placeholder and
    # intermediate nodes. So, we first canonicalize the graph.
    new_graph = torch.fx.Graph()
    env = {}

    placeholder_counter = itertools.count(0)

    def next_placeholder_name():
        nonlocal placeholder_counter
        return f"placeholder_{next(placeholder_counter)}"

    node_counter = itertools.count(0)

    def next_node_name():
        nonlocal node_counter
        return f"node_{next(node_counter)}"

    for node in gmod.graph.nodes:
        if node.op == "placeholder":
            env[node] = new_graph.placeholder(next_placeholder_name())
        else:
            # Can't use node_copy because node.name will not be unique.
            args = map_arg(node.args, lambda x: env[x])
            kwargs = map_arg(node.kwargs, lambda x: env[x])
            env[node] = new_graph.create_node(
                node.op, node.target, args, kwargs, next_node_name(), node.type
            )
        env[node].meta = copy.copy(node.meta)

    new_graph.lint()
    new_gmod = torch.fx.GraphModule(root_gmod, new_graph)
    return new_gmod


@functools.lru_cache(None)
def get_dummy_aot_autograd_config():
    from torch._functorch._aot_autograd.schemas import AOTConfig

    return AOTConfig(
        fw_compiler=None,
        bw_compiler=None,
        inference_compiler=None,
        partition_fn=None,
        decompositions={},
        num_params_buffers=0,
        aot_id=0,
        keep_inference_input_mutations=False,
        dynamic_shapes=True,
        aot_autograd_arg_pos_to_source=None,
        is_export=False,
        no_tangents=False,
        enable_log=False,
    )


def hash_graph_and_inputs(tx, gmod, fake_inputs):
    # Here, we use the existing autograd_cache_key infrastructure to hash the
    # graph and fake inputs.
    # TODO(anijain2305) - Consider reorganizing autograd_cache_key such that
    # the namespaces are more intuitive. It is somewhat confusing that we are
    # calling an API from aot_autograd here.
    from torch._functorch._aot_autograd.autograd_cache import autograd_cache_key

    # autograd_cache_key is sensitive to the names of the placeholder nodes.
    # So, we first canonicalize the graph.
    canonicalized_gmod = canonicalize(gmod, tx.output.nn_modules)
    config = get_dummy_aot_autograd_config()
    key, _ = autograd_cache_key(canonicalized_gmod, fake_inputs, config, {})
    return key


class InvokeSubgraphHigherOrderVariable(WrapHigherOrderVariable):
    def install_subgraph_in_output_graph(
        self, tx, fn_vt, fn_args_vt, kwargs, body_gmod, attr_name="invoke_subgraph"
    ):
        # Check if the subgraph from speculate_subgraph (body_gmod) and the
        # fake inputs have already been seen. If so, the subgraph is already
        # installed in the output graph and we can just access it using the
        # saved attr name.
        fake_inputs = [arg.as_proxy().node.meta["example_value"] for arg in fn_args_vt]
        key = hash_graph_and_inputs(tx, body_gmod, fake_inputs)

        if key in tx.output.seen_invoke_subgraphs:
            return tx.output.seen_invoke_subgraphs[key]

        body_name = super().install_subgraph_in_output_graph(
            tx, fn_vt, fn_args_vt, kwargs, body_gmod, attr_name
        )
        tx.output.seen_invoke_subgraphs[key] = body_name
        return body_name

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        # This flattens the kwargs into lifted args
        (
            p_args,
            p_kwargs,
            example_value,
            body_r,
            treespec,
            _,
            body_name,
        ) = self.create_wrapped_node(
            tx, args[0], args[2].items, kwargs, "invoke_subgraph"
        )

        if len(p_kwargs) > 0:
            unimplemented("kwargs should have been flattened into lifted args")

        flat_example_value = pytree.tree_map_only(
            torch.fx.Proxy,
            lambda a: a.node.meta["example_value"],
            body_r.as_proxy(),
        )

        p_args = (
            p_args[0],
            body_name,
            p_args[1:],
        )
        return _call_function_and_unflatten_output(
            tx, self.value, tuple(p_args), p_kwargs, flat_example_value, treespec
        )


@@ -4,12 +4,14 @@ from torch._higher_order_ops.flex_attention import (
    flex_attention_backward,
)
from torch._higher_order_ops.hints_wrap import hints_wrapper
from torch._higher_order_ops.invoke_subgraph import invoke_subgraph
from torch._higher_order_ops.while_loop import while_loop


__all__ = [
    "cond",
    "while_loop",
    "invoke_subgraph",
    "flex_attention",
    "flex_attention_backward",
    "hints_wrapper",


@@ -0,0 +1,235 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from typing import List, Optional, Tuple, Union

import torch
import torch.utils._pytree as pytree
from torch._C import DispatchKey
from torch._dispatch.python import suspend_functionalization
from torch._higher_order_ops.utils import (
    _from_fun,
    _maybe_reenter_make_fx,
    clone_outputs_aliasing_inputs,
    get_dummy_aot_autograd_config,
    prepare_fw_with_masks,
    reenter_make_fx,
)
from torch._ops import HigherOrderOperator
from torch._subclasses import FakeTensorMode
from torch._subclasses.functional_tensor import disable_functional_mode
from torch.fx.experimental.proxy_tensor import (
    disable_proxy_modes_tracing,
    ProxyTorchDispatchMode,
    track_tensor_tree,
)
from torch.fx.graph_module import GraphModule


invoke_subgraph_counter = 0


class InvokeSubgraphHOP(HigherOrderOperator):
    def __init__(self) -> None:
        super().__init__("invoke_subgraph")

    # The identifier is set up by the upper part of the stack. It helps us
    # identify when two invoke_subgraph calls have the same subgraph.
    def __call__(
        self,
        subgraph: GraphModule,
        identifier: Optional[str],
        operands: Union[
            List[Union[torch.Tensor, torch.SymInt]],
            Tuple[Union[torch.Tensor, torch.SymInt]],
        ],
    ):
        assert identifier is None or isinstance(
            identifier, str
        ), "identifier must be a None or a string"
        assert isinstance(
            operands, (list, tuple)
        ), f"invoke_subgraph operands must be a list or tuple of tensors and SymInts {operands}"
        assert all(
            isinstance(o, (torch.Tensor, torch.SymInt)) for o in operands
        ), f"invoke_subgraph operands must be a list of tensors and SymInts {operands}"
        return super().__call__(subgraph, identifier, operands)


invoke_subgraph = InvokeSubgraphHOP()


def trace_joint_graph(fn, fw_inputs, fw_outputs):
    """
    Naively trace out a joint graph. This simplifies the reconstruction of the
    joint graph in the min-cut partitioner later on.
    """
    from torch._functorch.aot_autograd import create_joint

    dummy_aot_config = get_dummy_aot_autograd_config()

    def joint_fn(*primals_and_tangents):
        primals = primals_and_tangents[: len(fw_inputs)]
        tangents = primals_and_tangents[len(fw_inputs) :]

        fw_outs, grads = create_joint(
            prepare_fw_with_masks(fn), aot_config=dummy_aot_config
        )(primals, tangents)

        maybe_clone = clone_outputs_aliasing_inputs(primals_and_tangents)
        return pytree.tree_map(maybe_clone, list(fw_outs) + grads)

    primals = list(fw_inputs)
    # This assumes that the tangent strides match the fw_outputs strides.
    # Check the InvokeSubgraphAutogradOp backward op for the contiguous call.
    tangents = [_from_fun(out) for out in fw_outputs]

    joint_operands = primals + tangents

    return _maybe_reenter_make_fx(joint_fn)(*joint_operands)


def create_fw_bw_graph(subgraph, operands):
    with suspend_functionalization(), disable_functional_mode():
        with disable_proxy_modes_tracing():
            # args are functional tensors, generate some example tensors
            fw_inputs = pytree.tree_map(_from_fun, operands)

            fw_outputs = pytree.tree_map(_from_fun, subgraph(*fw_inputs))
            if any(
                not isinstance(out, torch.Tensor)
                for out in fw_outputs
                if out is not None
            ):
                raise RuntimeError(
                    "Expect outputs of invoke_subgraph to only contains tensors or None. "
                    f"Got types {[type(out) for out in fw_outputs]}."
                )

            # Trace the forward subgraph
            fw_graph = _maybe_reenter_make_fx(subgraph)(*fw_inputs)

            # Trace the joint graph and assign it to the bwd graph
            bw_graph = trace_joint_graph(
                subgraph,
                fw_inputs,
                fw_outputs,
            )
            return fw_graph, bw_graph, len(fw_outputs)


class InvokeSubgraphAutogradOp(torch.autograd.Function):
    """
    This autograd function op stashes the backward graph in the ctx while
    running forward.
    """

    @staticmethod
    def forward(ctx, fw_graph, bw_graph, identifier, num_fw_outs, *operands):
        ctx._fw_graph = fw_graph
        ctx._bw_graph = bw_graph
        ctx._identifier = identifier
        ctx._num_fw_outs = num_fw_outs

        with torch._C._AutoDispatchBelowAutograd():
            out = invoke_subgraph(
                fw_graph,
                f"___forward_{identifier}",
                operands,
            )

        ctx.save_for_backward(*operands)
        return out

    @staticmethod
    def backward(ctx, *grad_outs):
        bw_graph = ctx._bw_graph
        identifier = ctx._identifier
        primals = ctx.saved_tensors
        num_fw_outs = ctx._num_fw_outs

        # While tracing we made the assumption that tangents are contiguous.
        # So, force the grad_outs to be contiguous.
        contiguous_grad_outs = tuple([o.contiguous() for o in grad_outs])

        # bw_graph is a joint graph with signature (*primals_and_tangents) and
        # returns (*fw_outs_and_grads). To get the grads, we use num_fw_outs
        # to slice off the forward outputs.
        primals_and_tangents = primals + contiguous_grad_outs
        grads = invoke_subgraph(
            bw_graph, f"___backward_{identifier}", primals_and_tangents
        )[num_fw_outs:]
        return None, None, None, None, *grads


@invoke_subgraph.py_impl(DispatchKey.CompositeExplicitAutograd)
def _(subgraph, identifier, operands):
    from torch.utils._python_dispatch import _get_current_dispatch_mode

    mode = _get_current_dispatch_mode()
    assert mode is None, "Mode should never be enabled for CPU/CUDA key"
    return subgraph(*operands)


@invoke_subgraph.py_impl(DispatchKey.Autograd)
def _(subgraph, identifier, operands):
    if not torch.is_grad_enabled():
        with torch._C._AutoDispatchBelowAutograd():
            return invoke_subgraph(subgraph, identifier, operands)

    # A shortcut for the case where none of the inputs require gradient: we
    # skip tracing the forward and backward graphs.
    if pytree.tree_all_only(
        torch.Tensor,
        lambda t: not t.requires_grad,  # type: ignore[union-attr]
        operands,
    ):
        with torch._C._AutoDispatchBelowAutograd():
            return invoke_subgraph(subgraph, identifier, operands)

    fw_graph, bw_graph, num_fw_outs = create_fw_bw_graph(subgraph, operands)

    # TODO(anijain2305) - Implement caching of autograd function op.
    return InvokeSubgraphAutogradOp.apply(
        fw_graph, bw_graph, identifier, num_fw_outs, *operands
    )


@invoke_subgraph.py_functionalize_impl
def _(ctx, subgraph, identifier, operands):
    unwrapped_operands = ctx.unwrap_tensors(operands)
    with ctx.redispatch_to_next() as m:
        # NB: There is an assumption that the subgraph does not mutate inputs
        # and that there is no aliasing. It is Dynamo's responsibility to
        # prevent formation of invoke_subgraph ops if input aliasing/mutation
        # is detected.
        functionalized_subgraph = ctx.functionalize(subgraph)
        out = invoke_subgraph(functionalized_subgraph, identifier, unwrapped_operands)
    return ctx.wrap_tensors(out)


@invoke_subgraph.py_impl(FakeTensorMode)
def _(mode, subgraph, identifier, operands):
    # TODO(anijain2305) - Implement fake tensor caching.
    return subgraph(*operands)


@invoke_subgraph.py_impl(ProxyTorchDispatchMode)
def _(proxy_mode: ProxyTorchDispatchMode, subgraph, identifier, operands):
    # TODO(anijain2305) - Implement proxy tensor caching.
    example_out = invoke_subgraph(subgraph, identifier, operands)
    graph = reenter_make_fx(subgraph)(*operands)
    assert isinstance(proxy_mode.tracer, torch.fx.Tracer)
    qualname = proxy_mode.tracer.get_fresh_qualname("repeated_subgraph")
    proxy_mode.tracer.root.register_module(qualname, graph)

    node_args = (graph, identifier, operands)
    proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
    out_proxy = proxy_mode.tracer.create_proxy(
        "call_function", invoke_subgraph, proxy_args, {}
    )
    return track_tensor_tree(
        example_out, out_proxy, constant=None, tracer=proxy_mode.tracer
    )

View file

@@ -444,3 +444,17 @@ def saved_tensors_and_symints(ctx):
        s_idx += 1
    assert t_idx + s_idx == len(ctx.pos)
    return tuple(args)


def get_dummy_aot_autograd_config():
    from torch._functorch.aot_autograd import AOTConfig

    return AOTConfig(
        fw_compiler=None,  # type: ignore[arg-type]
        bw_compiler=None,  # type: ignore[arg-type]
        partition_fn=None,  # type: ignore[arg-type]
        decompositions={},
        num_params_buffers=0,
        aot_id=0,
        keep_inference_input_mutations=False,
    )


@@ -102,6 +102,17 @@ def simple_cond(x):
    return torch.cond(x.sum() > 2, lambda x: (x.cos(),), lambda x: (x.sin(),), [x])


def sample_inputs_invoke_subgraph(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = functools.partial(
        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
    )
    yield SampleInput(make_arg(2, 2, 2, low=0.1, high=2))


def simple_invoke_subgraph(x):
    def fn(x):
        return (torch.sin(x),)
    return torch._higher_order_ops.invoke_subgraph(fn, None, (x,))


def sample_inputs_auto_functionalize(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = functools.partial(
        make_tensor, device=device, dtype=dtype, requires_grad=False
@@ -151,6 +162,21 @@ def simple_while_loop(iter_t, x):


hop_db = [
    OpInfo(
        name="invoke_subgraph",
        variant_test_name="simple",
        op=simple_invoke_subgraph,
        sample_inputs_func=sample_inputs_invoke_subgraph,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
        check_batched_grad=False,
        check_batched_gradgrad=False,
        check_batched_forward_grad=False,
        check_inplace_batched_forward_grad=False,
        supports_autograd=True,
        # "torch.compile with aot_autograd does not currently support double backward."
        supports_gradgrad=False,
    ),
    OpInfo(
        name="map",
        variant_test_name="simple",