[effects] Add inductor support for tokens (#122347)

Given the following code/dynamo graph:
```
class GraphModule(torch.nn.Module):
    def forward(self, L_x_ : torch.Tensor):
        l_x_ = L_x_
        _print = torch.ops.aten._print('moo')
        res = l_x_ + l_x_;  l_x_ = None
        _print_1 = torch.ops.aten._print('moo')
        return (res,)
```

AOTAutograd traces the following program, threading tokens from the inputs, through the effectful operator calls (`torch.ops.aten._print`), and out through the outputs:
```
class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "f32[0]", arg1_1: "f32[2, 3]"):
        with_effects = torch._higher_order_ops.effects.with_effects(arg0_1, torch.ops.aten._print.default, 'moo');  arg0_1 = None
        getitem: "f32[0]" = with_effects[0];  with_effects = None
        add: "f32[2, 3]" = torch.ops.aten.add.Tensor(arg1_1, arg1_1);  arg1_1 = None
        with_effects_1 = torch._higher_order_ops.effects.with_effects(getitem, torch.ops.aten._print.default, 'moo');  getitem = None
        getitem_2: "f32[0]" = with_effects_1[0];  with_effects_1 = None
        return (getitem_2, add)
```
However, by the time we reach inductor, we want the generated code to have no token inputs/outputs, for better readability. So we modify the aten graph: the tokens are removed from the inputs, created inside the graph through `torch.ops.aten._make_dep_token`, and sunk at the end through the `torch.ops.aten._sink_tokens` operators.
This has to be done *after* the partitioner; otherwise the partitioner would place the make_token/sink_token operators in the backward graph.
```
class <lambda>(torch.nn.Module):
   def forward(self, arg1_1: "f32[2, 3]"):
       _make_dep_token_default: "f32[0]" = torch.ops.aten._make_dep_token.default()
       with_effects = torch._higher_order_ops.effects.with_effects(_make_dep_token_default, torch.ops.aten._print.default, 'moo');  _make_dep_token_default = None
       getitem: "f32[0]" = with_effects[0];  with_effects = None
       add: "f32[2, 3]" = torch.ops.aten.add.Tensor(arg1_1, arg1_1);  arg1_1 = None
       with_effects_1 = torch._higher_order_ops.effects.with_effects(getitem, torch.ops.aten._print.default, 'moo');  getitem = None
       getitem_2: "f32[0]" = with_effects_1[0];  with_effects_1 = None
       _sink_tokens_default = torch.ops.aten._sink_tokens.default((getitem_2,));  getitem_2 = None
       return (add,)
```
When doing inductor lowering, we convert each `with_effects` call to an `EffectfulKernel`, which is just a `FallbackKernel` with a pointer to the previous effectful operator's call. During scheduling, we create a `StarDep` between an `EffectfulKernel` and its predecessor so that they don't get reordered. The inductor-generated Python code looks like:
```
def call(args):
    arg1_1, = args
    args.clear()
    assert_size_stride(arg1_1, (2, 3), (3, 1))
    # Source Nodes: [_print], Original ATen: []
    buf2 = aten._print.default('moo')
    # Source Nodes: [_print_1], Original ATen: []
    buf3 = aten._print.default('moo')
    buf4 = empty_strided_cpu((2, 3), (3, 1), torch.float32)
    cpp_fused_add_0(arg1_1, buf4)
    del arg1_1
    return (buf4, )
```
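
For reference, here is a minimal end-to-end repro, adapted from the test updated in this PR (the input shape is illustrative):
```python
import torch

def f(x):
    torch.ops.aten._print("moo")
    res = x + x
    torch.ops.aten._print("moo")
    return res

# The two prints stay ordered relative to each other because the compiled
# graph threads effect tokens through them internally.
out = torch.compile(f, backend="inductor")(torch.ones(2, 3))
```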

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122347
Approved by: https://github.com/bdhirsh
angelayi 2024-04-08 14:40:03 -07:00, committed by PyTorch MergeBot
commit 493478db4a (parent 9bd6d6e8b0)
18 changed files with 289 additions and 48 deletions


@@ -3,6 +3,7 @@ import unittest
import torch
import torch._dynamo
import torch._functorch
import torch._inductor
import torch._inductor.decomposition
from torch._functorch.aot_autograd import aot_export_module
@@ -17,7 +18,6 @@ from torch.testing._internal.common_utils import (
    IS_SANDCASTLE,
    IS_WINDOWS,
    run_tests,
    skipIfTorchDynamo,
    TestCase,
)
@@ -70,6 +70,22 @@ def forward(self, arg0_1, arg1_1):
        self.assertEqual(len(gs.input_tokens), 1)
        self.assertEqual(len(gs.output_tokens), 1)

        with torch._functorch.config.patch(unlift_effect_tokens=True):
            gm, gs = aot_export_module(M(), inputs, trace_joint=False)
            self.assertExpectedInline(
                str(gm.code).strip(),
                """\
def forward(self, arg1_1):
    _make_token_default = torch.ops.prims._make_token.default()
    with_effects = torch._higher_order_ops.effects.with_effects(_make_token_default, torch.ops.aten._print.default, 'moo'); _make_token_default = None
    getitem = with_effects[0]; with_effects = None
    add = torch.ops.aten.add.Tensor(arg1_1, arg1_1); arg1_1 = None
    with_effects_1 = torch._higher_order_ops.effects.with_effects(getitem, torch.ops.aten._print.default, 'moo'); getitem = None
    getitem_2 = with_effects_1[0]; with_effects_1 = None
    _sink_tokens_default = torch.ops.prims._sink_tokens.default((getitem_2,)); getitem_2 = None
    return (add,)""",  # noqa: B950
            )

    def test_torchbind_custom_op(self):
        class M(torch.nn.Module):
            def __init__(self):
@@ -174,12 +190,7 @@ def forward(self, arg0_1, arg1_1, arg2_1):
        res = torch.compile(f, backend="aot_eager")(*inputs)
        self.assertTrue(torch.allclose(res, f(*inputs)))

    @skipIfTorchDynamo(
        "We're testing if the test works with inductor, which it currently"
        "doesn't, so we expectedFailure-d the test, but the Dynamo tests"
        "override the backend, causing an unexpected success"
    )
    @unittest.expectedFailure  # NYI: AssertionError: with_effects is not an OpOverload
    @unittest.skipIf(IS_WINDOWS, "triton")
    def test_compile_inductor(self):
        def f(x):
            torch.ops.aten._print("moo")


@@ -9,6 +9,7 @@ from functorch.compile import min_cut_rematerialization_partition
import torch
from torch import _guards
from torch._functorch import config as functorch_config
from torch._functorch.compilers import ts_compile
from .common import aot_autograd
from .registry import register_debug_backend as register_backend
@@ -76,26 +77,32 @@ register_backend(
    name="aot_eager_default_partitioner", compiler_fn=aot_eager_default_partitioner
)

# Uses TorchInductor AOT Autograd decomps and partitioner to isolate aot vs
# inductor problems.
# aot_eager_decomp_partition just replaces the inductor compiler with nop to help
# isolate inductor vs aot_eager errors
aot_eager_decomp_partition = aot_autograd(
    # these are taken from memory_efficient_fusion()
    fw_compiler=boxed_nop,
    bw_compiler=boxed_nop,
    # NB: lambda here is to delay import of inductor
    decompositions=lambda: import_module(
        "torch._inductor.compile_fx"
    ).select_decomp_table(),
    partition_fn=functools.partial(
        min_cut_rematerialization_partition, compiler="inductor"
    ),
)
def aot_eager_decomp_partition(gm, fake_tensor_inputs):
    with functorch_config.patch(unlift_effect_tokens=True):
        return aot_autograd(
            # these are taken from memory_efficient_fusion()
            fw_compiler=boxed_nop,
            bw_compiler=boxed_nop,
            # NB: lambda here is to delay import of inductor
            decompositions=lambda: import_module(
                "torch._inductor.compile_fx"
            ).select_decomp_table(),
            partition_fn=functools.partial(
                min_cut_rematerialization_partition, compiler="inductor"
            ),
        )(gm, fake_tensor_inputs)


register_backend(
    name="aot_eager_decomp_partition", compiler_fn=aot_eager_decomp_partition
)
# AOT Autograd with torchscript backend. Default partitioner.
# aot_ts uses torchscript backend. We can use this with both nnc and nvfuser
# by using the relevant fuser with torch.jit.fuser(...)
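
As a quick sketch of exercising this backend (the toy function below is illustrative, not from the PR):
```python
import torch

def f(x):
    torch.ops.aten._print("moo")
    return x + x

# aot_eager_decomp_partition now traces with unlift_effect_tokens=True, so
# the boxed_nop-compiled forward contains _make_token/_sink_tokens calls
# instead of token inputs/outputs.
torch.compile(f, backend="aot_eager_decomp_partition")(torch.ones(2, 3))
```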


@@ -14,6 +14,7 @@ from torch._logging import getArtifactLogger, trace_structured
from torch._subclasses.functional_tensor import FunctionalTensorMode
from torch.fx.experimental.proxy_tensor import make_fx

from .. import config
from .functional_utils import (
    assert_functional_graph,
    propagate_input_mutation_stacktraces,
@@ -26,6 +27,7 @@ from .traced_function_transforms import (
    fn_input_mutations_to_outputs,
    fn_prepped_for_autograd,
)
from .utils import unlift_tokens

aot_graphs_log = getArtifactLogger(__name__, "aot_graphs")
@@ -102,6 +104,14 @@ def aot_dispatch_base_graph(
    copy_count2 = assert_functional_graph(fw_module.graph)
    propagate_input_mutation_stacktraces(fw_module.graph)

    # See Note [Side-Effectful Tokens in AOTAutograd]
    num_tokens = len(fw_metadata.tokens)
    if num_tokens != 0 and config.unlift_effect_tokens:
        unlift_tokens(fw_module, fw_metadata)
        updated_flat_args_subclasses_desugared = updated_flat_args_subclasses_desugared[
            num_tokens:
        ]

    assert copy_count == copy_count2

    if aot_config.enable_log:


@@ -56,6 +56,7 @@ from .utils import (
    make_boxed_func,
    normalize_as_list,
    strict_zip,
    unlift_tokens,
)

zip = strict_zip
@@ -270,20 +271,25 @@ def aot_dispatch_autograd(
            fw_metadata, inner_meta
        )
    )

    num_tokens = len(fw_metadata.tokens)
    num_mutated_inp_runtime_indices = len(mutated_inp_runtime_indices)
    num_inner_fwd_outputs = (
        num_mutated_inp_runtime_indices
        + inner_meta.num_outputs
        + inner_meta.num_intermediate_bases
        + inner_meta.num_outputs_rng_offset
        + len(
            fw_metadata.tokens
        )  # See Note [Side-Effectful Tokens in AOTAutograd]
        + num_tokens  # See Note [Side-Effectful Tokens in AOTAutograd]
    )
    fw_module, bw_module = aot_config.partition_fn(
        fx_g, joint_inputs, num_fwd_outputs=num_inner_fwd_outputs
    )

    # See Note [Side-Effectful Tokens in AOTAutograd]
    if num_tokens != 0 and config.unlift_effect_tokens:
        unlift_tokens(fw_module, fw_metadata)
        num_inner_fwd_outputs -= num_tokens
        joint_inputs = (joint_inputs[0][num_tokens:], joint_inputs[1])

    fw_outs = next(iter(fw_module.graph.find_nodes(op="output"))).args[0]
    # we only need to bookkeep the symints that are saved for bw, not any symints
    # the user forward might have returned in its own output


@@ -69,15 +69,16 @@ def create_runtime_wrapper(
    keep_input_mutations: bool,
    disable_amp: bool,
):
    num_tokens = len(runtime_metadata.tokens)

    if not hasattr(compiled_fn, "_boxed_call"):
        compiled_fn = make_boxed_func(compiled_fn)

    def runtime_wrapper(*args):
        # Pass in effect tokens (See Note [Side-Effectful Tokens in AOTAutograd])
        if num_tokens > 0:
            args = (*[torch.empty(0)] * num_tokens, *args)
        num_tokens = len(runtime_metadata.tokens)
        if config.unlift_effect_tokens:
            assert num_tokens == 0
        elif num_tokens > 0:
            # Pass in effect tokens (See Note [Side-Effectful Tokens in AOTAutograd])
            args = ([None] * num_tokens, *args)

        if trace_joint:
            args_ = list(args)


@@ -3,6 +3,7 @@ Contains various utils for AOTAutograd, including those for handling collections
"""
import dataclasses
import operator
import warnings
from contextlib import nullcontext
from functools import wraps
@@ -224,3 +225,58 @@ def maybe_to_fresh_input(idx, t, meta):
        # sees the tensor before the metadata mutation
        return t.view(t.shape)
    return t
def unlift_tokens(fw_module, fw_metadata):
    # Remove the tokens from the inputs/outputs of the graph since inductor does
    # not want these extra inputs/outputs, and replace them with
    # _make_token() to create a token, and _sink_tokens() to collect the
    # tokens. See Note [Side-Effectful Tokens in AOTAutograd]
    num_tokens = len(fw_metadata.tokens)

    input_token_nodes = []
    for i, node in enumerate(fw_module.graph.nodes):
        if i < num_tokens:
            # The first num_tokens placeholders are the lifted tokens.
            assert node.op == "placeholder"
            input_token_nodes.append(node)

        elif node.op == "call_function" and node.target.__name__ == "with_effects":
            if node.args[0] in input_token_nodes:
                # Replace the input token argument with a freshly created token.
                with fw_module.graph.inserting_before(node):
                    new_token_node = fw_module.graph.call_function(
                        torch.ops.prims._make_token.default, ()
                    )
                    new_token_node.meta["val"] = torch.tensor([])
                    new_token_node.meta["tensor_meta"] = torch.tensor([])

                    args = list(node.args)
                    args[0] = new_token_node
                    node.args = tuple(args)

        elif node.op == "output":
            output_token_nodes = node.args[0][:num_tokens]
            other_output_args = node.args[0][num_tokens:]

            for output_token_node in output_token_nodes:
                assert (
                    output_token_node.op == "call_function"
                    and output_token_node.target == operator.getitem
                    and output_token_node.args[1] == 0
                )
            # Sink the output tokens instead of returning them.
            with fw_module.graph.inserting_before(node):
                sink_token_node = fw_module.graph.call_function(
                    torch.ops.prims._sink_tokens.default,
                    (output_token_nodes,),
                )
            node.args = (other_output_args,)

    for input_token_node in input_token_nodes:
        fw_module.graph.erase_node(input_token_node)

    fw_module.recompile()

    # This is sad, but we need to update the metadata to get rid of
    # the tokens.
    fw_metadata.num_forward_returns -= num_tokens
    fw_metadata.num_forward -= num_tokens
    fw_metadata.tokens = {}


@@ -399,6 +399,22 @@ AOT_COUNTER = itertools.count()
# So the signature of the graph input would look something like
# (*tokens, *params_buffers, *user_inputs), and the signature of the graph
# output would look something like (*tokens, *outputs).
#
# However, Inductor does not want the concept of tokens in the final generated
# code's inputs and outputs. Since changing the graph signature inside of inductor
# is difficult, after generating the forward graph, we will run a pass to
# remove the tokens from the inputs/outputs and generate the following graph
# for Inductor, where the tokens are created and sunk within the graph, rather
# than being passed in as inputs and outputs:
#
# def gm(self, reader):
# token0 = torch.ops.prims._make_token()
# token1, frame = with_token(ordered_effect_op, (reader,), token0)
# frame = frame * 2
# token2, frame2 = with_token(ordered_effect_op, (reader,), token1)
# frame2 = frame2 * 2
# sink_token = torch.ops.prims._sink_tokens([token2])
# return frame, frame2
#
#


@@ -69,6 +69,13 @@ aggressive_recomputation = False
# is to turn it off during torch.compile.
fake_tensor_allow_unsafe_data_ptr_access = True

# Unlifts effect tokens from the inputs/outputs in the traced graph and instead
# inserts make_token/sink_token calls in the graph to create tokens and then
# sink them at the end. Note that this means the graph is no longer functional,
# which may lead to silent errors unless the backend knows how to handle the
# tokens.
unlift_effect_tokens = False

if TYPE_CHECKING:
    from torch.utils._config_typing import *  # noqa: F401, F403
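
A usage sketch mirroring the new test above (the module and input are illustrative):
```python
import torch
from torch._functorch import config as functorch_config
from torch._functorch.aot_autograd import aot_export_module

class M(torch.nn.Module):
    def forward(self, x):
        torch.ops.aten._print("moo")
        return x + x

# With the flag patched on, the exported forward creates and sinks its own
# effect tokens instead of taking them as extra inputs/outputs.
with functorch_config.patch(unlift_effect_tokens=True):
    gm, gs = aot_export_module(M(), (torch.ones(2, 3),), trace_joint=False)
print(gm.code)
```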


@@ -38,6 +38,7 @@ from torch._dynamo.utils import (
    lazy_format_graph_code,
    optimus_scuba_log,
)
from torch._functorch import config as functorch_config
from torch._functorch.aot_autograd import aot_export_module, make_boxed_func
from torch._inductor.codecache import code_hash, CompiledFxGraph, FxGraphCache
from torch._inductor.cudagraph_utils import BoxedDeviceIndex
@@ -1374,9 +1375,13 @@ def compile_fx(
        )

        if V.aot_compilation is True:
            gm, graph_signature = aot_export_module(
                model_, example_inputs_, trace_joint=False, decompositions=decompositions
            )
            with functorch_config.patch(unlift_effect_tokens=True):
                gm, graph_signature = aot_export_module(
                    model_,
                    example_inputs_,
                    trace_joint=False,
                    decompositions=decompositions,
                )
            unlifted_gm = _unlift_graph(model_, gm, graph_signature)
            if "dynamo_flat_name_to_original_fqn" in model_.meta:
                unlifted_gm.meta["dynamo_flat_name_to_original_fqn"] = model_.meta[
@@ -1387,7 +1392,7 @@ def compile_fx(
    with V.set_fake_mode(fake_mode), torch._guards.tracing(
        tracing_context
    ), compiled_autograd.disable():
    ), compiled_autograd.disable(), functorch_config.patch(unlift_effect_tokens=True):
        return aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,


@@ -16,6 +16,7 @@ import torch._logging
import torch.fx
from torch._decomp import get_decompositions
from torch._dynamo.utils import defake, dynamo_timed
from torch._higher_order_ops.effects import _EffectType
from torch._logging import LazyString, trace_structured
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx.experimental._backward_state import BackwardState
@@ -333,6 +334,8 @@ class GraphLowering(torch.fx.Interpreter):
        )
        self.init_backend_registration()

        self.effectful_ops: Dict[_EffectType, ir.Buffer] = {}

    @staticmethod
    def decide_layout_opt(gm, *, is_inference) -> bool:
        """
@@ -635,7 +638,10 @@ class GraphLowering(torch.fx.Interpreter):
        self.buffers.append(buffer)
        self.name_to_buffer[name] = buffer
        # Skip empty CPU tensor so that CUDA graphs can succeed, see https://github.com/pytorch/pytorch/pull/114144
        if not isinstance(buffer, ir.ComputedBuffer) or not buffer.is_zero_elements():
        if (
            not (isinstance(buffer, ir.ComputedBuffer) and buffer.is_zero_elements())
            and buffer.get_device() is not None
        ):
            self.add_device_info(buffer.get_device())
        return name
@@ -911,6 +917,7 @@ class GraphLowering(torch.fx.Interpreter):
                    sympy.Expr,
                    sympy.logic.boolalg.Boolean,
                    int,
                    ir.EffectfulKernel,
                ),
            )
            for x in result


@@ -123,7 +123,9 @@ def validate_ir(node_or_nodes):
    def _check_tensorbox(nodes):
        # Could expand this to check deeper properties
        # (e.g. TensorBox points to View or StorageBox)
        if isinstance(nodes, (list, tuple)):
        if nodes is None:
            pass
        elif isinstance(nodes, (list, tuple)):
            for node in nodes:
                _check_tensorbox(node)
        elif isinstance(nodes, dict):
@@ -139,6 +141,7 @@ def validate_ir(node_or_nodes):
                TensorBox,
                sympy.logic.boolalg.Boolean,
                Expr,
                EffectfulKernel,
            ),
        ), f"Found {type(nodes)}, which is not a supported top level IR node. See [Note: Inductor IR]"
@@ -5377,16 +5380,26 @@ class FallbackKernel(ExternKernelAlloc):
            unflatten_args,
        ) = cls.process_kernel(kernel, *args, **kwargs)

        device = cls.find_device(tensor_args, example_output)
        assert device, "Not sure where to find device info"
        if example_output is None:
            packed = cls(
                NoneLayout(None),
                kernel,
                tensor_args,
                non_tensor_args,
                unflatten_args,
            )
        packed = cls(
            MultiOutputLayout(device),
            kernel,
            tensor_args,
            non_tensor_args,
            unflatten_args,
        )
        else:
            device = cls.find_device(tensor_args, example_output)
            assert device, "Not sure where to find device info"
            packed = cls(
                MultiOutputLayout(device),
                kernel,
                tensor_args,
                non_tensor_args,
                unflatten_args,
            )

        def generate_output(output, indices):
            if isinstance(output, (list, tuple)):
@@ -7343,6 +7356,47 @@ class WhileLoop(ExternKernel):
        wrapper.codegen_while_loop(self)


class EffectfulKernel(FallbackKernel):
    def __init__(
        self,
        layout,
        kernel,
        tensor_args,
        nontensor_args,
        unflatten_args,
        kwargs=None,
    ):
        super().__init__(
            NoneLayout(layout.device),
            kernel,
            tensor_args,
            nontensor_args,
            unflatten_args,
            kwargs=None,
        )

        from torch._higher_order_ops.effects import get_effect_key

        effect_type = get_effect_key(kernel, (*nontensor_args, *tensor_args), kwargs)
        assert effect_type is not None
        self.effect_type = effect_type
        self.prev_effect_buffer = V.graph.effectful_ops.get(effect_type, None)
        V.graph.effectful_ops[effect_type] = self

    def get_read_writes(self):
        read_writes = super().get_read_writes()

        if self.prev_effect_buffer is not None:
            # Add a StarDep on the previous kernel with the same effect type so
            # the scheduler cannot reorder the two effectful calls.
            read_writes.reads.add(
                dependencies.StarDep(self.prev_effect_buffer.get_name())
            )

        return read_writes

    def has_side_effects(self):
        return True


class InterpreterShim(torch.fx.Interpreter):
    @staticmethod
    @functools.lru_cache(None)


@@ -207,7 +207,8 @@ def transform_args(args, broadcast, type_promotion_kind, convert_input_to_bool):
    promoting_args = [
        a
        for a in args
        if isinstance(a, (Number, sympy.Expr)) or hasattr(a, "dtype")
        if isinstance(a, (Number, sympy.Expr))
        or getattr(a, "dtype", None) is not None
    ]
    dtype = get_promoted_dtype(
        *promoting_args, type_promotion_kind=type_promotion_kind
@@ -6110,6 +6111,31 @@ def templated_attention(*args, **kwargs):
        raise ValueError("TemplatedAttention was passed a subgraph with no output node!")


@register_lowering(torch.ops.prims._sink_tokens.default)
def _sink_tokens(tokens):
    return None


@register_lowering(torch.ops.higher_order.with_effects)
def with_effects(token, op, *args, **kwargs):
    result = ir.EffectfulKernel.create(op, *args, **kwargs)

    from torch._higher_order_ops.effects import get_effect_key

    effect_type = get_effect_key(op, args, kwargs)
    assert effect_type is not None
    effectful_kernel = V.graph.effectful_ops[effect_type]

    if result is None:
        return (effectful_kernel,)

    result = pytree.tree_map_only(ir.MultiOutput, TensorBox.create, result)
    if not isinstance(result, (list, tuple)):
        return (effectful_kernel, result)
    else:
        return (effectful_kernel, *result)


try:
    import torch.distributed._functional_collectives


@@ -2440,8 +2440,9 @@ class Scheduler:
            self.enter_context(node)

            if not isinstance(node, NopKernelSchedulerNode):
                device = node.get_device()
            if not isinstance(node, NopKernelSchedulerNode) and (
                device := node.get_device()
            ):
                if (
                    device != self.current_device
                    or node.is_extern()
@@ -2484,7 +2485,7 @@ class Scheduler:
            if not isinstance(node, NopKernelSchedulerNode):
                device = node.get_device()
                if self.get_backend(device).ready_to_flush():
                if device is not None and self.get_backend(device).ready_to_flush():
                    self.flush()

        if self.current_device and device_need_guard(self.current_device.type):


@@ -614,7 +614,7 @@ def make_dep_token(
    pin_memory=None,
    memory_format=None,
):
    return torch.empty([], device="meta")
    return torch.empty(0, device="meta")


@register_meta(aten.sym_constrain_range.default)


@@ -212,6 +212,11 @@ __all__ = [
    "fft_r2c",
    "fft_c2c",
    "fft_c2r",
    #
    # prims for making/sinking tokens
    #
    "_make_token",
    "_sink_tokens",
]
@@ -3027,5 +3032,32 @@ frexp = _make_prim(
    doc="",
)


def _make_token_aten() -> TensorLikeType:
    return torch.empty(0)


_make_token = _make_prim(
    schema="_make_token() -> Tensor",
    meta=_make_token_aten,
    return_type=RETURN_TYPE.NEW,
    impl_aten=_make_token_aten,
    doc="Creates a token used for keeping track of side effects.",
)


def _sink_tokens_aten(tokens) -> None:
    pass


_sink_tokens = _make_prim(
    schema="_sink_tokens(Tensor[] tokens) -> ()",
    meta=_sink_tokens_aten,
    return_type=RETURN_TYPE.NONE,
    impl_aten=_sink_tokens_aten,
    doc="Sink all of the tokens which were previously used for keeping track of side effects.",
)


register_rng_prims()
register_debug_prims()
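
These prims are ordinary dispatchable ops, which is how `unlift_tokens` emits them into the graph; a minimal sketch of calling them directly, assuming eager dispatch works for them as for other prims:
```python
import torch

# _make_token returns an empty tensor standing in for an effect token;
# _sink_tokens consumes a sequence of tokens and returns nothing.
token = torch.ops.prims._make_token.default()
torch.ops.prims._sink_tokens.default([token])
```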


@@ -1344,6 +1344,7 @@ class RETURN_TYPE(Enum):
    NEW = (0,)
    VIEW = (1,)
    INPLACE = (2,)
    NONE = (3,)


# TODO: when NumberType contains the sym types, can simplify this


@@ -1489,7 +1489,7 @@ torch::utils::maybe_initialize_device(options);
        # we're an output-arg variant, check these args against output tensor
        if not f.func.is_out_fn():
            raise RuntimeError(
                f"{f.func}: dtype in tensor_options_args without output arg"
                f"{f.func}: dtype in tensor_options_args without output arg, {ps} {ps.arguments}"
            )
        if not all(a in tensor_options_args_names for a in ("layout", "device")):
            raise RuntimeError(


@@ -81,6 +81,7 @@ FUNCTIONAL_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT = [
    "_chunk_grad_outputs_efficient_attention",  # returns a bool
    "_fused_sdp_choice",  # returns an int
    "_print",  # no return
    "_sink_tokens",  # no return
    "_nested_get_ragged_idx",  # returns an int
]
]