[effects] Add inductor support for tokens (#122347)
Given the following code/dynamo graph:
```
class GraphModule(torch.nn.Module):
def forward(self, L_x_ : torch.Tensor):
l_x_ = L_x_
_print = torch.ops.aten._print('moo')
res = l_x_ + l_x_; l_x_ = None
_print_1 = torch.ops.aten._print('moo')
return (res,)
```
AOTAutograd will trace the following program, threading tokens from the inputs, through the effectful operator calls (`torch.ops.aten._print`), and out as outputs:
```
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[0]", arg1_1: "f32[2, 3]"):
with_effects = torch._higher_order_ops.effects.with_effects(arg0_1, torch.ops.aten._print.default, 'moo'); arg0_1 = None
getitem: "f32[0]" = with_effects[0]; with_effects = None
add: "f32[2, 3]" = torch.ops.aten.add.Tensor(arg1_1, arg1_1); arg1_1 = None
with_effects_1 = torch._higher_order_ops.effects.with_effects(getitem, torch.ops.aten._print.default, 'moo'); getitem = None
getitem_2: "f32[0]" = with_effects_1[0]; with_effects_1 = None
return (getitem_2, add)
```
However, once we get to Inductor, we want the Inductor-generated code to have no token inputs or outputs, for better readability. So we modify the ATen graph: the tokens are removed from the inputs, created inside the graph through `torch.ops.aten._make_dep_token`, and sunk at the end through the `torch.ops.aten._sink_tokens` operators.
This has to be done *after* the partitioner; otherwise the partitioner would add the make_token/sink_token operators to the backward graph.
```
class <lambda>(torch.nn.Module):
def forward(self, arg1_1: "f32[2, 3]"):
_make_dep_token_default: "f32[0]" = torch.ops.aten._make_dep_token.default()
with_effects = torch._higher_order_ops.effects.with_effects(_make_dep_token_default, torch.ops.aten._print.default, 'moo'); _make_dep_token_default = None
getitem: "f32[0]" = with_effects[0]; with_effects = None
add: "f32[2, 3]" = torch.ops.aten.add.Tensor(arg1_1, arg1_1); arg1_1 = None
with_effects_1 = torch._higher_order_ops.effects.with_effects(getitem, torch.ops.aten._print.default, 'moo'); getitem = None
getitem_2: "f32[0]" = with_effects_1[0]; with_effects_1 = None
_sink_tokens_default = torch.ops.aten._sink_tokens.default((getitem_2,)); getitem_2 = None
return (add,)
```
When doing Inductor lowering, we convert `with_effects` calls to an `EffectfulKernel`, which is just a `FallbackKernel` with a pointer to the previous effectful operator's call. During scheduling, we create a `StarDep` between an `EffectfulKernel` and its previous `EffectfulKernel` so that they don't get reordered. The Inductor-generated Python code looks like:
```
def call(args):
arg1_1, = args
args.clear()
assert_size_stride(arg1_1, (2, 3), (3, 1))
# Source Nodes: [_print], Original ATen: []
buf2 = aten._print.default('moo')
# Source Nodes: [_print_1], Original ATen: []
buf3 = aten._print.default('moo')
buf4 = empty_strided_cpu((2, 3), (3, 1), torch.float32)
cpp_fused_add_0(arg1_1, buf4)
del arg1_1
return (buf4, )
```
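For reference, the user-level code that exercises this path is just a compiled function calling the effectful print op twice, mirroring the Dynamo graph at the top and the `test_compile_inductor` test in the diff below (a minimal sketch, not taken verbatim from the PR):
```
import torch

def f(x):
    torch.ops.aten._print("moo")
    res = x + x
    torch.ops.aten._print("moo")
    return (res,)

# Compiling with the inductor backend produces the token-free generated code shown above.
compiled = torch.compile(f, backend="inductor")
out = compiled(torch.randn(2, 3))
```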
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122347
Approved by: https://github.com/bdhirsh
This commit is contained in:
parent 9bd6d6e8b0 · commit 493478db4a
18 changed files with 289 additions and 48 deletions
```
@@ -3,6 +3,7 @@ import unittest
import torch
import torch._dynamo
import torch._functorch
import torch._inductor
import torch._inductor.decomposition
from torch._functorch.aot_autograd import aot_export_module

@@ -17,7 +18,6 @@ from torch.testing._internal.common_utils import (
    IS_SANDCASTLE,
    IS_WINDOWS,
    run_tests,
    skipIfTorchDynamo,
    TestCase,
)

@@ -70,6 +70,22 @@ def forward(self, arg0_1, arg1_1):
        self.assertEqual(len(gs.input_tokens), 1)
        self.assertEqual(len(gs.output_tokens), 1)

        with torch._functorch.config.patch(unlift_effect_tokens=True):
            gm, gs = aot_export_module(M(), inputs, trace_joint=False)
            self.assertExpectedInline(
                str(gm.code).strip(),
                """\
def forward(self, arg1_1):
    _make_token_default = torch.ops.prims._make_token.default()
    with_effects = torch._higher_order_ops.effects.with_effects(_make_token_default, torch.ops.aten._print.default, 'moo'); _make_token_default = None
    getitem = with_effects[0]; with_effects = None
    add = torch.ops.aten.add.Tensor(arg1_1, arg1_1); arg1_1 = None
    with_effects_1 = torch._higher_order_ops.effects.with_effects(getitem, torch.ops.aten._print.default, 'moo'); getitem = None
    getitem_2 = with_effects_1[0]; with_effects_1 = None
    _sink_tokens_default = torch.ops.prims._sink_tokens.default((getitem_2,)); getitem_2 = None
    return (add,)""",  # noqa: B950
            )

    def test_torchbind_custom_op(self):
        class M(torch.nn.Module):
            def __init__(self):

@@ -174,12 +190,7 @@ def forward(self, arg0_1, arg1_1, arg2_1):
        res = torch.compile(f, backend="aot_eager")(*inputs)
        self.assertTrue(torch.allclose(res, f(*inputs)))

    @skipIfTorchDynamo(
        "We're testing if the test works with inductor, which it currently"
        "doesn't, so we expectedFailure-d the test, but the Dynamo tests"
        "override the backend, causing an unexpected success"
    )
    @unittest.expectedFailure  # NYI: AssertionError: with_effects is not an OpOverload
    @unittest.skipIf(IS_WINDOWS, "triton")
    def test_compile_inductor(self):
        def f(x):
            torch.ops.aten._print("moo")
```
```
@@ -9,6 +9,7 @@ from functorch.compile import min_cut_rematerialization_partition
import torch
from torch import _guards
from torch._functorch import config as functorch_config
from torch._functorch.compilers import ts_compile
from .common import aot_autograd
from .registry import register_debug_backend as register_backend

@@ -76,26 +77,32 @@ register_backend(
    name="aot_eager_default_partitioner", compiler_fn=aot_eager_default_partitioner
)


# Uses TorchInductor AOT Autograd decomps and partitioner to isolate aot vs
# inductor problems.
# aot_eager_decomp_partition just replaces the inductor compiler with nop to help
# isolate inductor vs aot_eager errors
aot_eager_decomp_partition = aot_autograd(
    # these are taken from memory_efficient_fusion()
    fw_compiler=boxed_nop,
    bw_compiler=boxed_nop,
    # NB: lambda here is to delay import of inductor
    decompositions=lambda: import_module(
        "torch._inductor.compile_fx"
    ).select_decomp_table(),
    partition_fn=functools.partial(
        min_cut_rematerialization_partition, compiler="inductor"
    ),
)
def aot_eager_decomp_partition(gm, fake_tensor_inputs):
    with functorch_config.patch(unlift_effect_tokens=True):
        return aot_autograd(
            # these are taken from memory_efficient_fusion()
            fw_compiler=boxed_nop,
            bw_compiler=boxed_nop,
            # NB: lambda here is to delay import of inductor
            decompositions=lambda: import_module(
                "torch._inductor.compile_fx"
            ).select_decomp_table(),
            partition_fn=functools.partial(
                min_cut_rematerialization_partition, compiler="inductor"
            ),
        )(gm, fake_tensor_inputs)


register_backend(
    name="aot_eager_decomp_partition", compiler_fn=aot_eager_decomp_partition
)


# AOT Autograd with torchscript backend. Default partitioner.
# aot_ts uses torchscript backend. We can use this with both nnc and nvfuser
# by using the relevant fuser with torch.jit.fuser(...)
```
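Since the backend above is registered under the name `aot_eager_decomp_partition`, it can be selected through `torch.compile` to isolate AOTAutograd/partitioner problems from Inductor codegen problems. A hedged sketch of that usage (the function `f` is just a stand-in):
```
import torch

def f(x):
    torch.ops.aten._print("moo")
    return x + x

# Runs AOTAutograd with the inductor decomps/partitioner but no-op compilers,
# with unlift_effect_tokens patched on as in the backend definition above.
compiled = torch.compile(f, backend="aot_eager_decomp_partition")
out = compiled(torch.randn(2, 3))
```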
```
@@ -14,6 +14,7 @@ from torch._logging import getArtifactLogger, trace_structured
from torch._subclasses.functional_tensor import FunctionalTensorMode
from torch.fx.experimental.proxy_tensor import make_fx

from .. import config
from .functional_utils import (
    assert_functional_graph,
    propagate_input_mutation_stacktraces,

@@ -26,6 +27,7 @@ from .traced_function_transforms import (
    fn_input_mutations_to_outputs,
    fn_prepped_for_autograd,
)
from .utils import unlift_tokens

aot_graphs_log = getArtifactLogger(__name__, "aot_graphs")


@@ -102,6 +104,14 @@ def aot_dispatch_base_graph(
    copy_count2 = assert_functional_graph(fw_module.graph)
    propagate_input_mutation_stacktraces(fw_module.graph)

    # See Note [Side-Effectful Tokens in AOTAutograd]
    num_tokens = len(fw_metadata.tokens)
    if num_tokens != 0 and config.unlift_effect_tokens:
        unlift_tokens(fw_module, fw_metadata)
        updated_flat_args_subclasses_desugared = updated_flat_args_subclasses_desugared[
            num_tokens:
        ]

    assert copy_count == copy_count2

    if aot_config.enable_log:
```
```
@@ -56,6 +56,7 @@ from .utils import (
    make_boxed_func,
    normalize_as_list,
    strict_zip,
    unlift_tokens,
)

zip = strict_zip

@@ -270,20 +271,25 @@ def aot_dispatch_autograd(
            fw_metadata, inner_meta
        )
    )
    num_tokens = len(fw_metadata.tokens)
    num_mutated_inp_runtime_indices = len(mutated_inp_runtime_indices)
    num_inner_fwd_outputs = (
        num_mutated_inp_runtime_indices
        + inner_meta.num_outputs
        + inner_meta.num_intermediate_bases
        + inner_meta.num_outputs_rng_offset
        + len(
            fw_metadata.tokens
        )  # See Note [Side-Effectful Tokens in AOTAutograd]
        + num_tokens  # See Note [Side-Effectful Tokens in AOTAutograd]
    )
    fw_module, bw_module = aot_config.partition_fn(
        fx_g, joint_inputs, num_fwd_outputs=num_inner_fwd_outputs
    )

    # See Note [Side-Effectful Tokens in AOTAutograd]
    if num_tokens != 0 and config.unlift_effect_tokens:
        unlift_tokens(fw_module, fw_metadata)
        num_inner_fwd_outputs -= num_tokens
        joint_inputs = (joint_inputs[0][num_tokens:], joint_inputs[1])

    fw_outs = next(iter(fw_module.graph.find_nodes(op="output"))).args[0]
    # we only need to bookkeep the symints that are saved for bw, not any symints
    # the user forward might have returned in its own output
```
```
@@ -69,15 +69,16 @@ def create_runtime_wrapper(
    keep_input_mutations: bool,
    disable_amp: bool,
):
    num_tokens = len(runtime_metadata.tokens)

    if not hasattr(compiled_fn, "_boxed_call"):
        compiled_fn = make_boxed_func(compiled_fn)

    def runtime_wrapper(*args):
        # Pass in effect tokens (See Note [Side-Effectful Tokens in AOTAutograd])
        if num_tokens > 0:
            args = (*[torch.empty(0)] * num_tokens, *args)
        num_tokens = len(runtime_metadata.tokens)
        if config.unlift_effect_tokens:
            assert num_tokens == 0
        elif num_tokens > 0:
            # Pass in effect tokens (See Note [Side-Effectful Tokens in AOTAutograd])
            args = ([None] * num_tokens, *args)

        if trace_joint:
            args_ = list(args)
```
```
@@ -3,6 +3,7 @@ Contains various utils for AOTAutograd, including those for handling collections
"""

import dataclasses
import operator
import warnings
from contextlib import nullcontext
from functools import wraps

@@ -224,3 +225,58 @@ def maybe_to_fresh_input(idx, t, meta):
        # sees the tensor before the metadata mutation
        return t.view(t.shape)
    return t


def unlift_tokens(fw_module, fw_metadata):
    # Remove the tokens from the inputs/outputs of the graph since inductor does
    # not want these extra inputs/outputs, and replace them with
    # _make_token() to create a token, and _sink_tokens() to collect the
    # tokens. See Note [Side-Effectful Tokens in AOTAutograd]
    num_tokens = len(fw_metadata.tokens)

    input_token_nodes = []
    for i, node in enumerate(fw_module.graph.nodes):
        if i < num_tokens:
            assert node.op == "placeholder"
            input_token_nodes.append(node)

        elif node.op == "call_function" and node.target.__name__ == "with_effects":
            if node.args[0] in input_token_nodes:
                with fw_module.graph.inserting_before(node):
                    new_token_node = fw_module.graph.call_function(
                        torch.ops.prims._make_token.default, ()
                    )
                    new_token_node.meta["val"] = torch.tensor([])
                    new_token_node.meta["tensor_meta"] = torch.tensor([])

                    args = list(node.args)
                    args[0] = new_token_node
                    node.args = tuple(args)

        elif node.op == "output":
            output_token_nodes = node.args[0][:num_tokens]
            other_output_args = node.args[0][num_tokens:]

            for output_token_node in output_token_nodes:
                assert (
                    output_token_node.op == "call_function"
                    and output_token_node.target == operator.getitem
                    and output_token_node.args[1] == 0
                )
            with fw_module.graph.inserting_before(node):
                sink_token_node = fw_module.graph.call_function(
                    torch.ops.prims._sink_tokens.default,
                    (output_token_nodes,),
                )
                node.args = (other_output_args,)

    for input_token_node in input_token_nodes:
        fw_module.graph.erase_node(input_token_node)

    fw_module.recompile()

    # This is sad, but we need to update the metadata to get rid of
    # the tokens.
    fw_metadata.num_forward_returns -= num_tokens
    fw_metadata.num_forward -= num_tokens
    fw_metadata.tokens = {}
```
```
@@ -399,6 +399,22 @@ AOT_COUNTER = itertools.count()
# So the signature of the graph input would look something like
# (*tokens, *params_buffers, *user_inputs), and the signature of the graph
# output would look something like (*tokens, *outputs).
#
# However, Inductor does not want the concept of tokens in the final generated
# code's input and output. Since changing the graph signature inside of inductor
# is difficult, after generating the forward graph, we will run a pass to
# remove the tokens from the inputs and generate the following graph for Inductor, where
# the tokens are created and sunk within the graph, rather than as inputs and
# outputs:
#
# def gm(self, reader):
#     token0 = torch.ops.prims._make_token()
#     token1, frame = with_token(ordered_effect_op, (reader,), token0)
#     frame = frame * 2
#     token2, frame2 = with_token(ordered_effect_op, (reader,), token1)
#     frame2 = frame2 * 2
#     sink_token = torch.ops.prims._sink_tokens([token2])
#     return frame, frame2

#
#
```
```
@@ -69,6 +69,13 @@ aggressive_recomputation = False
# is to turn it off during torch.compile.
fake_tensor_allow_unsafe_data_ptr_access = True

# Unlifts effect tokens from the inputs/outputs in the traced graph and instead
# inserts make_token/sink_token calls in the graph to create tokens and then
# sink them at the end. Note that this means the graph is no longer functional
# which may lead to silent errors unless the backend knows how to handle the
# tokens.
unlift_effect_tokens = False

if TYPE_CHECKING:
    from torch.utils._config_typing import *  # noqa: F401, F403
```
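To see what this flag does end to end, one can export a small effectful module with and without it patched on, much as the new test above does. A minimal sketch (the toy module below is an assumption, not copied from the diff):
```
import torch
from torch._functorch.aot_autograd import aot_export_module

class M(torch.nn.Module):
    def forward(self, x):
        torch.ops.aten._print("moo")
        res = x + x
        torch.ops.aten._print("moo")
        return (res,)

inputs = (torch.randn(2, 3),)

# Default (False): the exported graph threads a token through its inputs/outputs.
gm, gs = aot_export_module(M(), inputs, trace_joint=False)

# Patched on: tokens are created with _make_token and sunk with _sink_tokens inside the graph.
with torch._functorch.config.patch(unlift_effect_tokens=True):
    gm_unlifted, _ = aot_export_module(M(), inputs, trace_joint=False)
print(gm_unlifted.code)
```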
```
@@ -38,6 +38,7 @@ from torch._dynamo.utils import (
    lazy_format_graph_code,
    optimus_scuba_log,
)
from torch._functorch import config as functorch_config
from torch._functorch.aot_autograd import aot_export_module, make_boxed_func
from torch._inductor.codecache import code_hash, CompiledFxGraph, FxGraphCache
from torch._inductor.cudagraph_utils import BoxedDeviceIndex

@@ -1374,9 +1375,13 @@ def compile_fx(
        )

    if V.aot_compilation is True:
        gm, graph_signature = aot_export_module(
            model_, example_inputs_, trace_joint=False, decompositions=decompositions
        )
        with functorch_config.patch(unlift_effect_tokens=True):
            gm, graph_signature = aot_export_module(
                model_,
                example_inputs_,
                trace_joint=False,
                decompositions=decompositions,
            )
        unlifted_gm = _unlift_graph(model_, gm, graph_signature)
        if "dynamo_flat_name_to_original_fqn" in model_.meta:
            unlifted_gm.meta["dynamo_flat_name_to_original_fqn"] = model_.meta[

@@ -1387,7 +1392,7 @@ def compile_fx(

    with V.set_fake_mode(fake_mode), torch._guards.tracing(
        tracing_context
    ), compiled_autograd.disable():
    ), compiled_autograd.disable(), functorch_config.patch(unlift_effect_tokens=True):
        return aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
```
```
@@ -16,6 +16,7 @@ import torch._logging
import torch.fx
from torch._decomp import get_decompositions
from torch._dynamo.utils import defake, dynamo_timed
from torch._higher_order_ops.effects import _EffectType
from torch._logging import LazyString, trace_structured
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx.experimental._backward_state import BackwardState

@@ -333,6 +334,8 @@ class GraphLowering(torch.fx.Interpreter):
        )
        self.init_backend_registration()

        self.effectful_ops: Dict[_EffectType, ir.Buffer] = {}

    @staticmethod
    def decide_layout_opt(gm, *, is_inference) -> bool:
        """

@@ -635,7 +638,10 @@ class GraphLowering(torch.fx.Interpreter):
        self.buffers.append(buffer)
        self.name_to_buffer[name] = buffer
        # Skip empty CPU tensor so that CUDA graphs can succeed, see https://github.com/pytorch/pytorch/pull/114144
        if not isinstance(buffer, ir.ComputedBuffer) or not buffer.is_zero_elements():
        if (
            not (isinstance(buffer, ir.ComputedBuffer) and buffer.is_zero_elements())
            and buffer.get_device() is not None
        ):
            self.add_device_info(buffer.get_device())
        return name

@@ -911,6 +917,7 @@ class GraphLowering(torch.fx.Interpreter):
                    sympy.Expr,
                    sympy.logic.boolalg.Boolean,
                    int,
                    ir.EffectfulKernel,
                ),
            )
            for x in result
```
```
@@ -123,7 +123,9 @@ def validate_ir(node_or_nodes):
    def _check_tensorbox(nodes):
        # Could expand this to check deeper properties
        # (e.g. TensorBox points to View or StorageBox)
        if isinstance(nodes, (list, tuple)):
        if nodes is None:
            pass
        elif isinstance(nodes, (list, tuple)):
            for node in nodes:
                _check_tensorbox(node)
        elif isinstance(nodes, dict):

@@ -139,6 +141,7 @@ def validate_ir(node_or_nodes):
                    TensorBox,
                    sympy.logic.boolalg.Boolean,
                    Expr,
                    EffectfulKernel,
                ),
            ), f"Found {type(nodes)}, which is not a supported top level IR node. See [Note: Inductor IR]"


@@ -5377,16 +5380,26 @@ class FallbackKernel(ExternKernelAlloc):
            unflatten_args,
        ) = cls.process_kernel(kernel, *args, **kwargs)

        device = cls.find_device(tensor_args, example_output)
        assert device, "Not sure where to find device info"
        if example_output is None:
            packed = cls(
                NoneLayout(None),
                kernel,
                tensor_args,
                non_tensor_args,
                unflatten_args,
            )

        packed = cls(
            MultiOutputLayout(device),
            kernel,
            tensor_args,
            non_tensor_args,
            unflatten_args,
        )
        else:
            device = cls.find_device(tensor_args, example_output)
            assert device, "Not sure where to find device info"

            packed = cls(
                MultiOutputLayout(device),
                kernel,
                tensor_args,
                non_tensor_args,
                unflatten_args,
            )

        def generate_output(output, indices):
            if isinstance(output, (list, tuple)):

@@ -7343,6 +7356,47 @@ class WhileLoop(ExternKernel):
        wrapper.codegen_while_loop(self)


class EffectfulKernel(FallbackKernel):
    def __init__(
        self,
        layout,
        kernel,
        tensor_args,
        nontensor_args,
        unflatten_args,
        kwargs=None,
    ):
        super().__init__(
            NoneLayout(layout.device),
            kernel,
            tensor_args,
            nontensor_args,
            unflatten_args,
            kwargs=None,
        )

        from torch._higher_order_ops.effects import get_effect_key

        effect_type = get_effect_key(kernel, (*nontensor_args, *tensor_args), kwargs)
        assert effect_type is not None
        self.effect_type = effect_type
        self.prev_effect_buffer = V.graph.effectful_ops.get(effect_type, None)
        V.graph.effectful_ops[effect_type] = self

    def get_read_writes(self):
        read_writes = super().get_read_writes()

        if self.prev_effect_buffer is not None:
            read_writes.reads.add(
                dependencies.StarDep(self.prev_effect_buffer.get_name())
            )

        return read_writes

    def has_side_effects(self):
        return True


class InterpreterShim(torch.fx.Interpreter):
    @staticmethod
    @functools.lru_cache(None)
```
```
@@ -207,7 +207,8 @@ def transform_args(args, broadcast, type_promotion_kind, convert_input_to_bool):
        promoting_args = [
            a
            for a in args
            if isinstance(a, (Number, sympy.Expr)) or hasattr(a, "dtype")
            if isinstance(a, (Number, sympy.Expr))
            or getattr(a, "dtype", None) is not None
        ]
        dtype = get_promoted_dtype(
            *promoting_args, type_promotion_kind=type_promotion_kind

@@ -6110,6 +6111,31 @@ def templated_attention(*args, **kwargs):
    raise ValueError("TemplatedAttention was passed a subgraph with no output node!")


@register_lowering(torch.ops.prims._sink_tokens.default)
def _sink_tokens(tokens):
    return None


@register_lowering(torch.ops.higher_order.with_effects)
def with_effects(token, op, *args, **kwargs):
    result = ir.EffectfulKernel.create(op, *args, **kwargs)

    from torch._higher_order_ops.effects import get_effect_key

    effect_type = get_effect_key(op, args, kwargs)
    assert effect_type is not None
    effectful_kernel = V.graph.effectful_ops[effect_type]

    if result is None:
        return (effectful_kernel,)

    result = pytree.tree_map_only(ir.MultiOutput, TensorBox.create, result)
    if not isinstance(result, (list, tuple)):
        return (effectful_kernel, result)
    else:
        return (effectful_kernel, *result)


try:
    import torch.distributed._functional_collectives
```
```
@@ -2440,8 +2440,9 @@ class Scheduler:

            self.enter_context(node)

            if not isinstance(node, NopKernelSchedulerNode):
                device = node.get_device()
            if not isinstance(node, NopKernelSchedulerNode) and (
                device := node.get_device()
            ):
                if (
                    device != self.current_device
                    or node.is_extern()

@@ -2484,7 +2485,7 @@ class Scheduler:

            if not isinstance(node, NopKernelSchedulerNode):
                device = node.get_device()
                if self.get_backend(device).ready_to_flush():
                if device is not None and self.get_backend(device).ready_to_flush():
                    self.flush()

        if self.current_device and device_need_guard(self.current_device.type):
```
```
@@ -614,7 +614,7 @@ def make_dep_token(
    pin_memory=None,
    memory_format=None,
):
    return torch.empty([], device="meta")
    return torch.empty(0, device="meta")


@register_meta(aten.sym_constrain_range.default)
```
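The meta change above switches the dep token from a 0-dim scalar tensor to a zero-element tensor, matching the `f32[0]` token values in the graphs in the PR description. The distinction, as a quick sketch:
```
import torch

print(torch.empty([]).numel())  # 1 -- a 0-dim (scalar) tensor
print(torch.empty(0).numel())   # 0 -- a zero-element tensor, like the f32[0] effect tokens
```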
```
@@ -212,6 +212,11 @@ __all__ = [
    "fft_r2c",
    "fft_c2c",
    "fft_c2r",
    #
    # prims for making/sinking tokens
    #
    "_make_token",
    "_sink_tokens",
]

@@ -3027,5 +3032,32 @@ frexp = _make_prim(
    doc="",
)


def _make_token_aten() -> TensorLikeType:
    return torch.empty(0)


_make_token = _make_prim(
    schema="_make_token() -> Tensor",
    meta=_make_token_aten,
    return_type=RETURN_TYPE.NEW,
    impl_aten=_make_token_aten,
    doc="Creates a token used for keeping track of side effects.",
)


def _sink_tokens_aten(tokens) -> None:
    pass


_sink_tokens = _make_prim(
    schema="_sink_tokens(Tensor[] tokens) -> ()",
    meta=_sink_tokens_aten,
    return_type=RETURN_TYPE.NONE,
    impl_aten=_sink_tokens_aten,
    doc="Sink all of the tokens which were previously used for keeping track of side effects.",
)


register_rng_prims()
register_debug_prims()
```
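With these prims registered, the token ops can also be invoked directly; they carry no data and exist only so that ordering dependencies between effectful ops can be expressed. A small sketch, assuming the registrations land exactly as written above:
```
import torch

token = torch.ops.prims._make_token()   # zero-element tensor used purely for ordering
print(token.numel())                    # 0
torch.ops.prims._sink_tokens([token])   # consumes the tokens at the end of a graph
```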
```
@@ -1344,6 +1344,7 @@ class RETURN_TYPE(Enum):
    NEW = (0,)
    VIEW = (1,)
    INPLACE = (2,)
    NONE = (3,)


# TODO: when NumberType contains the sym types, can simplify this
```
```
@@ -1489,7 +1489,7 @@ torch::utils::maybe_initialize_device(options);
        # we're an output-arg variant, check these args against output tensor
        if not f.func.is_out_fn():
            raise RuntimeError(
                f"{f.func}: dtype in tensor_options_args without output arg"
                f"{f.func}: dtype in tensor_options_args without output arg, {ps} {ps.arguments}"
            )
        if not all(a in tensor_options_args_names for a in ("layout", "device")):
            raise RuntimeError(
```
```
@@ -81,6 +81,7 @@ FUNCTIONAL_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT = [
    "_chunk_grad_outputs_efficient_attention",  # returns a bool
    "_fused_sdp_choice",  # returns an int
    "_print",  # no return
    "_sink_tokens",  # no return
    "_nested_get_ragged_idx",  # returns an int
]
```