[ca] fix flex attention backward HOP capture in initial graph (#143155)

FIXES https://github.com/pytorch/pytorch/issues/142313

With previous HOPs, compiled autograd could just inline into their bodies and capture their post-dispatch ATen representation. You can't do that with the flex attention HOP, which instead expects whatever proxy tracing mechanism is active to insert it into its graph as a single node. Compiled autograd does use proxy tracing for its initial capture, so we can do exactly that.
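A minimal, self-contained sketch (not this PR's code) of the distinction above: "inlining" records the ops inside a callable's body as separate graph nodes, while a proxy/FX tracer can instead record the callable itself as one opaque node, which is how a HOP gets inserted. `body` here is just a stand-in for a HOP's implementation.

```python
import torch
import torch.fx

def body(x):
    # stand-in for a HOP's decomposed body
    return torch.sin(x) * 2

# "Inlining": symbolic_trace walks into `body` and records sin/mul as separate nodes.
inlined = torch.fx.symbolic_trace(body)
print([node.target for node in inlined.graph.nodes])

# "Insert as a node": the graph holds `body` as a single call_function node,
# without looking inside it.
graph = torch.fx.Graph()
x = graph.placeholder("x")
out = graph.call_function(body, (x,))
graph.output(out)
gm = torch.fx.GraphModule({}, graph)
print([node.target for node in gm.graph.nodes])
```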

This is safe because, aside from the reenter_make_fx call, the HOP didn't use any other make_fx internals. And compiled autograd specializes on the AOT backward's saved symints, which should cover any changes in the shapes of the HOP's inputs.

However, there's still an issue: Dynamo doesn't know how to handle `FlexAttentionBackwardHOP` and will graph break on it, so the flex attention backward still runs in eager as of this PR. The tlparse also looks pretty scuffed after the compiled autograd capture: https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpMMHBEH/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10
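For context, a rough repro sketch of the scenario above (not part of the PR): it assumes a CUDA build, the shapes and the `aot_eager` backend mirror the test added below, and `compiler_fn` is an illustrative stand-in for the test helper `make_compiler_fn(fullgraph=False)`.

```python
import torch
import torch._dynamo.compiled_autograd as compiled_autograd
from torch.nn.attention.flex_attention import flex_attention

@torch.compile(backend="aot_eager")
def fwd_bwd(x):
    flex_attention(x, x, x).sum().backward()

v = torch.zeros(
    1, 1, 768, 64, dtype=torch.bfloat16, device="cuda", requires_grad=True
)

# fullgraph=False lets Dynamo graph break on flex_attention_backward instead of
# erroring, so that piece of the compiled autograd graph falls back to eager.
def compiler_fn(gm):
    return torch.compile(gm, backend="aot_eager", fullgraph=False)

with compiled_autograd.enable(compiler_fn):
    fwd_bwd(v)
print(v.grad.shape)
```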

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143155
Approved by: https://github.com/drisspg
commit 72fd7abb35 (parent b4f4c75e19)
Simon Fan, 2024-12-12 18:35:07 -08:00, committed by PyTorch MergeBot
2 changed files with 33 additions and 3 deletions

test/inductor/test_compiled_autograd.py

@@ -24,6 +24,7 @@ from torch._dynamo.device_interface import get_interface_for_device
 from torch._dynamo.utils import counters
 from torch._inductor import config as inductor_config
 from torch._inductor.test_case import run_tests, TestCase
+from torch.nn.attention.flex_attention import flex_attention
 from torch.testing._internal.common_utils import (
     scoped_load_inline,
     skipIfWindows,
@@ -3216,6 +3217,31 @@ TORCH_LIBRARY(test_cudagraphs_cpu_scalar_used_in_cpp_custom_op, m) {
         self.assertTrue(isinstance(view_nodes[0].args[1][0], torch.fx.Node))
         self.assertTrue(isinstance(view_nodes[1].args[1][0], torch.fx.Node))
 
+    @unittest.skipIf(not HAS_CUDA, "requires cuda")
+    def test_flex_attention(self):
+        def fn():
+            @torch.compile(backend="aot_eager")
+            def fwd_bwd(x: torch.Tensor):
+                flex_attention(x, x, x).sum().backward()
+
+            for a, b in zip([12, 24, 48], [64, 128, 256]):
+                v = torch.zeros(
+                    1,
+                    1,
+                    a * b,
+                    b,
+                    dtype=torch.bfloat16,
+                    device="cuda",
+                    requires_grad=True,
+                )
+                fwd_bwd(v)
+                yield v.grad
+
+        # TODO: Dynamo graph breaks on torch.ops.higher_order.flex_attention_backward
+        self.check_output_and_recompiles(
+            fn, count=3, compiler_fn=make_compiler_fn(fullgraph=False)
+        )
+
     @unittest.expectedFailure
     def test_saved_tensor_unpack_hook_ordering(self):
         # not the correct behaviour, I'm just preventing this from changing silently

torch/_higher_order_ops/flex_attention.py

@@ -7,6 +7,7 @@ from torch import Tensor
 from torch._C import DispatchKey
 from torch._higher_order_ops.utils import (
     _has_potential_branch_input_mutation,
+    _maybe_reenter_make_fx,
     autograd_not_implemented,
     reenter_make_fx,
     save_tensors_and_symints_for_backward,
@@ -945,11 +946,14 @@ def trace_flex_attention_backward(
     mask_example_vals = [query.new_zeros((), dtype=torch.int) for _ in range(4)]
     mask_graph = block_mask[-1]
     with TransformGetItemToIndex():
-        fw_graph = reenter_make_fx(fw_graph)(*fw_example_vals, *score_mod_other_buffers)
-        joint_graph = reenter_make_fx(joint_graph)(
+        # There's no active make_fx during the compiled autograd graph's initial capture
+        fw_graph = _maybe_reenter_make_fx(fw_graph)(
+            *fw_example_vals, *score_mod_other_buffers
+        )
+        joint_graph = _maybe_reenter_make_fx(joint_graph)(
             *bw_example_vals, *score_mod_other_buffers
         )
-        mask_graph = reenter_make_fx(mask_graph)(
+        mask_graph = _maybe_reenter_make_fx(mask_graph)(
             *mask_example_vals, *mask_mod_other_buffers
         )
     assert isinstance(proxy_mode.tracer, torch.fx.Tracer)
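For reference, a hypothetical illustration (not the actual implementation of `_maybe_reenter_make_fx` in `torch._higher_order_ops.utils`) of the fallback the diff relies on: when a make_fx trace is already active, reenter it so the subgraph attaches to the outer trace; otherwise, as during compiled autograd's initial capture, trace the callable standalone. The `inside_make_fx` flag is an illustrative stand-in for the real helper's check for an active tracer.

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._higher_order_ops.utils import reenter_make_fx

def maybe_reenter_make_fx_sketch(fn, inside_make_fx: bool):
    # Hypothetical stand-in: the real helper detects an active make_fx tracer itself.
    if inside_make_fx:
        # Attach the traced subgraph to the surrounding make_fx trace.
        return reenter_make_fx(fn)
    # No active make_fx (e.g. compiled autograd's initial graph capture):
    # trace the callable into a standalone GraphModule.
    return make_fx(fn)

# Example: trace a score_mod-style callable the way the backward HOP traces
# fw_graph, using scalar example values like the code above.
def score_mod(score, b, h, q_idx, kv_idx):
    return score + (q_idx - kv_idx)

example_vals = [torch.zeros((), dtype=torch.float32)] + [
    torch.zeros((), dtype=torch.int) for _ in range(4)
]
gm = maybe_reenter_make_fx_sketch(score_mod, inside_make_fx=False)(*example_vals)
print(gm.graph)
```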