[ca] fix flex attention backward HOP capture in initial graph (#143155)
FIXES https://github.com/pytorch/pytorch/issues/142313

With previous HOPs, compiled autograd could simply inline into their bodies and obtain their post-dispatch ATen representation. That does not work for the flex attention HOP, which expects whatever proxy tracing mechanism is active to insert it into its graph. Since compiled autograd itself uses proxy tracing, we can do exactly that. This is safe because, apart from the reenter_make_fx call, the HOP used no other make_fx internals, and compiled autograd specializes on the AOT backward's saved symints, which covers any change in the shapes of the HOP's inputs.

One issue remains: Dynamo does not know how to handle `FlexAttentionBackwardHOP` and will graph break, so as of this PR the flex attention backward runs in eager. The tlparse after the compiled autograd capture still looks rough: https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpMMHBEH/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143155
Approved by: https://github.com/drisspg
This commit is contained in:
parent b4f4c75e19
commit 72fd7abb35

2 changed files with 33 additions and 3 deletions
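For context, the scenario this commit fixes can be reproduced outside the test harness roughly as sketched below: flex attention runs under torch.compile while compiled autograd captures the backward graph, which forces the backward HOP to be inserted into the compiled autograd graph. This is a minimal sketch, not code from the PR; the torch._dynamo.compiled_autograd.enable usage and the chosen shape/dtype mirror how the test-suite helpers drive compiled autograd and are assumptions here.

# Minimal sketch of the fixed scenario (not part of the diff).
import torch
from torch._dynamo import compiled_autograd
from torch.nn.attention.flex_attention import flex_attention


@torch.compile(backend="aot_eager")
def fwd(x: torch.Tensor) -> torch.Tensor:
    return flex_attention(x, x, x).sum()


def train_step(x: torch.Tensor) -> None:
    # Compiled autograd captures the backward graph; before this PR the initial
    # capture failed because the flex attention backward HOP cannot be inlined.
    with compiled_autograd.enable(
        lambda gm: torch.compile(gm, backend="aot_eager", fullgraph=False)
    ):
        fwd(x).backward()


if torch.cuda.is_available():
    v = torch.zeros(
        1, 1, 768, 64, dtype=torch.bfloat16, device="cuda", requires_grad=True
    )
    train_step(v)
    print(v.grad.shape)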
test/inductor/test_compiled_autograd.py

@@ -24,6 +24,7 @@ from torch._dynamo.device_interface import get_interface_for_device
 from torch._dynamo.utils import counters
 from torch._inductor import config as inductor_config
 from torch._inductor.test_case import run_tests, TestCase
+from torch.nn.attention.flex_attention import flex_attention
 from torch.testing._internal.common_utils import (
     scoped_load_inline,
     skipIfWindows,
@@ -3216,6 +3217,31 @@ TORCH_LIBRARY(test_cudagraphs_cpu_scalar_used_in_cpp_custom_op, m) {
         self.assertTrue(isinstance(view_nodes[0].args[1][0], torch.fx.Node))
         self.assertTrue(isinstance(view_nodes[1].args[1][0], torch.fx.Node))
 
+    @unittest.skipIf(not HAS_CUDA, "requires cuda")
+    def test_flex_attention(self):
+        def fn():
+            @torch.compile(backend="aot_eager")
+            def fwd_bwd(x: torch.Tensor):
+                flex_attention(x, x, x).sum().backward()
+
+            for a, b in zip([12, 24, 48], [64, 128, 256]):
+                v = torch.zeros(
+                    1,
+                    1,
+                    a * b,
+                    b,
+                    dtype=torch.bfloat16,
+                    device="cuda",
+                    requires_grad=True,
+                )
+                fwd_bwd(v)
+                yield v.grad
+
+        # TODO: Dynamo graph breaks on torch.ops.higher_order.flex_attention_backward
+        self.check_output_and_recompiles(
+            fn, count=3, compiler_fn=make_compiler_fn(fullgraph=False)
+        )
+
     @unittest.expectedFailure
     def test_saved_tensor_unpack_hook_ordering(self):
         # not the correct behaviour, I'm just preventing this from changing silently
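The new test above exercises three input shapes and expects count=3 recompiles, since compiled autograd specializes on the AOT backward's saved symints; fullgraph=False is needed because Dynamo still graph breaks on the backward HOP. To check that the HOP lands in the compiled autograd graph as a node rather than being inlined, a recording compiler_fn along the following lines could be used; this is a sketch and an assumption, not part of the PR.

# Sketch (assumption, not from the PR): record compiled autograd graphs and
# check that the flex attention backward HOP appears as a graph node.
import torch
from torch._dynamo import compiled_autograd
from torch.nn.attention.flex_attention import flex_attention  # ensures the HOP is registered

captured_graphs = []


def recording_compiler(gm: torch.fx.GraphModule):
    captured_graphs.append(gm)
    return gm  # run the captured graph as-is


# with compiled_autograd.enable(recording_compiler):
#     loss.backward()

has_backward_hop = any(
    node.target is torch.ops.higher_order.flex_attention_backward
    for gm in captured_graphs
    for node in gm.graph.nodes
)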
torch/_higher_order_ops/flex_attention.py

@@ -7,6 +7,7 @@ from torch import Tensor
 from torch._C import DispatchKey
 from torch._higher_order_ops.utils import (
     _has_potential_branch_input_mutation,
+    _maybe_reenter_make_fx,
     autograd_not_implemented,
     reenter_make_fx,
     save_tensors_and_symints_for_backward,
@@ -945,11 +946,14 @@ def trace_flex_attention_backward(
     mask_example_vals = [query.new_zeros((), dtype=torch.int) for _ in range(4)]
     mask_graph = block_mask[-1]
     with TransformGetItemToIndex():
-        fw_graph = reenter_make_fx(fw_graph)(*fw_example_vals, *score_mod_other_buffers)
-        joint_graph = reenter_make_fx(joint_graph)(
+        # There's no active make_fx during the compiled autograd graph's initial capture
+        fw_graph = _maybe_reenter_make_fx(fw_graph)(
+            *fw_example_vals, *score_mod_other_buffers
+        )
+        joint_graph = _maybe_reenter_make_fx(joint_graph)(
             *bw_example_vals, *score_mod_other_buffers
         )
-        mask_graph = reenter_make_fx(mask_graph)(
+        mask_graph = _maybe_reenter_make_fx(mask_graph)(
             *mask_example_vals, *mask_mod_other_buffers
         )
         assert isinstance(proxy_mode.tracer, torch.fx.Tracer)
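The substantive change in this file is swapping reenter_make_fx for _maybe_reenter_make_fx when tracing the backward HOP's fw_graph, joint_graph, and mask_graph. During compiled autograd's initial graph capture there is no active make_fx tracer to re-enter, so the subgraphs have to be traced with a fresh make_fx instead. The sketch below only illustrates that decision; it is not the real torch._higher_order_ops.utils implementation, and the trace_subgraph helper is hypothetical.

# Illustrative sketch only -- NOT the actual torch._higher_order_ops.utils code.
# It shows the fallback this change relies on: re-enter an ambient make_fx trace
# when one exists, otherwise (e.g. compiled autograd's initial capture) start fresh.
from torch.fx.experimental.proxy_tensor import make_fx


def maybe_reenter_make_fx_sketch(fn, ambient_tracer=None):
    if ambient_tracer is not None:
        # An outer make_fx is active (e.g. AOTAutograd tracing the HOP body):
        # trace fn as a subgraph of that enclosing trace.
        # trace_subgraph is a hypothetical helper, used here for illustration only.
        return lambda *args: ambient_tracer.trace_subgraph(fn, args)
    # No ambient tracer: trace fn from scratch, as happens when compiled autograd
    # captures the flex attention backward HOP in its initial graph.
    return lambda *args: make_fx(fn)(*args)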