diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py
index 8270f0a0e22..356d3dd53da 100644
--- a/test/inductor/test_compiled_autograd.py
+++ b/test/inductor/test_compiled_autograd.py
@@ -24,6 +24,7 @@ from torch._dynamo.device_interface import get_interface_for_device
 from torch._dynamo.utils import counters
 from torch._inductor import config as inductor_config
 from torch._inductor.test_case import run_tests, TestCase
+from torch.nn.attention.flex_attention import flex_attention
 from torch.testing._internal.common_utils import (
     scoped_load_inline,
     skipIfWindows,
@@ -3216,6 +3217,31 @@ TORCH_LIBRARY(test_cudagraphs_cpu_scalar_used_in_cpp_custom_op, m) {
         self.assertTrue(isinstance(view_nodes[0].args[1][0], torch.fx.Node))
         self.assertTrue(isinstance(view_nodes[1].args[1][0], torch.fx.Node))
 
+    @unittest.skipIf(not HAS_CUDA, "requires cuda")
+    def test_flex_attention(self):
+        def fn():
+            @torch.compile(backend="aot_eager")
+            def fwd_bwd(x: torch.Tensor):
+                flex_attention(x, x, x).sum().backward()
+
+            for a, b in zip([12, 24, 48], [64, 128, 256]):
+                v = torch.zeros(
+                    1,
+                    1,
+                    a * b,
+                    b,
+                    dtype=torch.bfloat16,
+                    device="cuda",
+                    requires_grad=True,
+                )
+                fwd_bwd(v)
+                yield v.grad
+
+        # TODO: Dynamo graph breaks on torch.ops.higher_order.flex_attention_backward
+        self.check_output_and_recompiles(
+            fn, count=3, compiler_fn=make_compiler_fn(fullgraph=False)
+        )
+
     @unittest.expectedFailure
     def test_saved_tensor_unpack_hook_ordering(self):
         # not the correct behaviour, I'm just preventing this from changing silently
diff --git a/torch/_higher_order_ops/flex_attention.py b/torch/_higher_order_ops/flex_attention.py
index 8e2dff14401..7448a5eb598 100644
--- a/torch/_higher_order_ops/flex_attention.py
+++ b/torch/_higher_order_ops/flex_attention.py
@@ -7,6 +7,7 @@ from torch import Tensor
 from torch._C import DispatchKey
 from torch._higher_order_ops.utils import (
     _has_potential_branch_input_mutation,
+    _maybe_reenter_make_fx,
     autograd_not_implemented,
     reenter_make_fx,
     save_tensors_and_symints_for_backward,
@@ -945,11 +946,14 @@ def trace_flex_attention_backward(
     mask_example_vals = [query.new_zeros((), dtype=torch.int) for _ in range(4)]
     mask_graph = block_mask[-1]
     with TransformGetItemToIndex():
-        fw_graph = reenter_make_fx(fw_graph)(*fw_example_vals, *score_mod_other_buffers)
-        joint_graph = reenter_make_fx(joint_graph)(
+        # There's no active make_fx during the compiled autograd graph's initial capture
+        fw_graph = _maybe_reenter_make_fx(fw_graph)(
+            *fw_example_vals, *score_mod_other_buffers
+        )
+        joint_graph = _maybe_reenter_make_fx(joint_graph)(
             *bw_example_vals, *score_mod_other_buffers
         )
-        mask_graph = reenter_make_fx(mask_graph)(
+        mask_graph = _maybe_reenter_make_fx(mask_graph)(
             *mask_example_vals, *mask_mod_other_buffers
         )
         assert isinstance(proxy_mode.tracer, torch.fx.Tracer)
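
For context, not part of the patch: a minimal, self-contained sketch of what the _maybe_reenter_make_fx fallback amounts to when no outer make_fx trace is active, as during compiled autograd's initial capture of the backward graph. The score_mod body and example values below are illustrative, not taken from the patch; the real helper lives in torch/_higher_order_ops/utils.py and also covers the nested case by reentering the active tracer, which plain reenter_make_fx requires.

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx

    def score_mod(score, b, h, q_idx, kv_idx):
        # Illustrative score_mod: relative-position bias.
        return score + (q_idx - kv_idx)

    # Scalar example values, mirroring the fw/mask example values built in
    # trace_flex_attention_backward: one float score plus four int indices.
    example_vals = [torch.zeros((), dtype=torch.float)] + [
        torch.zeros((), dtype=torch.int) for _ in range(4)
    ]

    # With no enclosing make_fx trace, a fresh make_fx call still succeeds;
    # this standalone-trace fallback is what the new comment in the last hunk
    # relies on, where reenter_make_fx alone would fail.
    gm = make_fx(score_mod)(*example_vals)
    print(gm.graph)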