diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py
index 7fb48d30499..b2e9231aa90 100644
--- a/test/inductor/test_flex_attention.py
+++ b/test/inductor/test_flex_attention.py
@@ -1527,6 +1527,54 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
         tolerance = Tolerances(atol=2e-1, rtol=2e-1)
         torch.testing.assert_close(out, out2, atol=tolerance.atol, rtol=tolerance.rtol)
 
+    @supported_platform
+    def test_multiple_mask_calls(self):
+        if TEST_WITH_ROCM:
+            self.skipTest(
+                "ROCM BUG SEE: https://github.com/pytorch/pytorch/issues/140855"
+            )
+        # Create inputs
+        query = torch.randn(
+            (1, 4, 512, 64), dtype=torch.float32, device="cuda", requires_grad=True
+        )
+        key = torch.randn(
+            (1, 4, 512, 64), dtype=torch.float32, device="cuda", requires_grad=True
+        )
+        value = torch.randn(
+            (1, 4, 512, 64), dtype=torch.float32, device="cuda", requires_grad=True
+        )
+
+        window_size = 32
+
+        def causal_mask(b, h, q_idx, kv_idx):
+            return q_idx >= kv_idx
+
+        def causal_mask_slidewindow_mod(b, h, q_idx, kv_idx):
+            return (q_idx >= kv_idx) & (q_idx <= kv_idx + window_size)
+
+        mask1 = create_block_mask(causal_mask, 1, None, 512, 512, _compile=False)
+        mask2 = create_block_mask(
+            causal_mask_slidewindow_mod, 1, None, 512, 512, _compile=False
+        )
+
+        def f(q, k, v):
+            out1 = flex_attention(q, k, v, block_mask=mask1)
+            out2 = flex_attention(q, k, v, block_mask=mask2)
+            return out1 + out2
+
+        f_compiled = torch.compile(f, fullgraph=True)
+
+        out = f(query, key, value)
+        out_compiled = f_compiled(query, key, value)
+
+        grads = torch.autograd.grad((out,), (query, key, value), torch.ones_like(out))
+        grads_compile = torch.autograd.grad(
+            (out_compiled,), (query, key, value), torch.ones_like(out_compiled)
+        )
+
+        for grad, grad_compiled in zip(grads, grads_compile):
+            torch.testing.assert_close(grad, grad_compiled, atol=3e-2, rtol=3e-2)
+
     @supported_platform
     def test_multiple_score_mod_calls2(self):
         query = torch.randn((1, 8, 1024, 64), dtype=torch.float32, device="cuda")
@@ -3184,21 +3232,21 @@ class GraphModule(torch.nn.Module):
 class GraphModule(torch.nn.Module):
     def forward(self, primals_1: "f64[2, 2, 128, 4]", primals_2: "f64[2, 2, 128, 4]", primals_3: "f64[2, 2, 128, 4]", full: "i32[1, 1, 1]", full_default: "i32[1, 1, 1, 1]", convert_element_type: "i32[1, 1, 1]", convert_element_type_1: "i32[1, 1, 1, 1]", getitem_2: "f64[2, 2, 128, 4]", getitem_3: "f32[2, 2, 128]", tangents_1: "f64[2, 2, 128, 4]"):
         full_default_4: "f32[2, 2, 128]" = torch.ops.aten.full.default([2, 2, 128], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
-        fw_graph = self.fw_graph
-        joint_graph = self.joint_graph
-        mask_graph = self.mask_graph
-        flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, getitem_2, getitem_3, tangents_1, full_default_4, fw_graph, joint_graph, (full, full_default, None, None, convert_element_type, convert_element_type_1, None, None, 1073741824, 1073741824, mask_graph), 0.5, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'OUTPUT_LOGSUMEXP': True}, (), ()); primals_1 = primals_2 = primals_3 = getitem_2 = getitem_3 = tangents_1 = full_default_4 = fw_graph = joint_graph = full = full_default = convert_element_type = convert_element_type_1 = mask_graph = None
+        fw_graph0 = self.fw_graph0
+        joint_graph0 = self.joint_graph0
+        mask_graph0 = self.mask_graph0
+        flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, getitem_2, getitem_3, tangents_1, full_default_4, fw_graph0, joint_graph0, (full, full_default, None, None, convert_element_type, convert_element_type_1, None, None, 1073741824, 1073741824, mask_graph0), 0.5, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'OUTPUT_LOGSUMEXP': True}, (), ()); primals_1 = primals_2 = primals_3 = getitem_2 = getitem_3 = tangents_1 = full_default_4 = fw_graph0 = joint_graph0 = full = full_default = convert_element_type = convert_element_type_1 = mask_graph0 = None
         getitem_4: "f64[2, 2, 128, 4]" = flex_attention_backward[0]
         getitem_5: "f64[2, 2, 128, 4]" = flex_attention_backward[1]
         getitem_6: "f64[2, 2, 128, 4]" = flex_attention_backward[2]; flex_attention_backward = None
         return (getitem_4, getitem_5, getitem_6)
 
-    class fw_graph(torch.nn.Module):
+    class fw_graph0(torch.nn.Module):
         def forward(self, arg0_1: "f64[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i32[]", arg4_1: "i32[]"):
             mul: "f64[]" = torch.ops.aten.mul.Tensor(arg0_1, arg0_1); arg0_1 = None
             return mul
 
-    class joint_graph(torch.nn.Module):
+    class joint_graph0(torch.nn.Module):
         def forward(self, arg0_1: "f64[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i32[]", arg4_1: "i32[]", arg5_1: "f64[]"):
             mul: "f64[]" = torch.ops.aten.mul.Tensor(arg0_1, arg0_1); mul = None
             mul_1: "f64[]" = torch.ops.aten.mul.Tensor(arg5_1, arg0_1)
@@ -3206,7 +3254,7 @@ class GraphModule(torch.nn.Module):
             add: "f64[]" = torch.ops.aten.add.Tensor(mul_2, mul_1); mul_2 = mul_1 = None
             return [add, None, None, None, None]
 
-    class mask_graph(torch.nn.Module):
+    class mask_graph0(torch.nn.Module):
         def forward(self, arg0_1: "i32[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i32[]"):
             full: "b8[]" = torch.ops.aten.full.default([], True, dtype = torch.bool, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
             return full
diff --git a/torch/_higher_order_ops/flex_attention.py b/torch/_higher_order_ops/flex_attention.py
index 54ba0cb6e5d..932d7440ab4 100644
--- a/torch/_higher_order_ops/flex_attention.py
+++ b/torch/_higher_order_ops/flex_attention.py
@@ -946,9 +946,14 @@ def trace_flex_attention_backward(
     )
     assert isinstance(proxy_mode.tracer, torch.fx.Tracer)
     block_mask = block_mask[:-1] + (mask_graph,)
-    proxy_mode.tracer.root.register_module("fw_graph", fw_graph)  # type: ignore[arg-type]
-    proxy_mode.tracer.root.register_module("joint_graph", joint_graph)
-    proxy_mode.tracer.root.register_module("mask_graph", mask_graph)
+
+    qualname = proxy_mode.tracer.get_fresh_qualname("fw_graph")
+    proxy_mode.tracer.root.register_module(qualname, fw_graph)  # type: ignore[arg-type]
+    qualname = proxy_mode.tracer.get_fresh_qualname("joint_graph")
+    proxy_mode.tracer.root.register_module(qualname, joint_graph)
+    qualname = proxy_mode.tracer.get_fresh_qualname("mask_graph")
+    proxy_mode.tracer.root.register_module(qualname, mask_graph)
+
    node_args = (
        query,
        key,
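Note: the core of the fix is swapping the hard-coded "fw_graph"/"joint_graph"/"mask_graph" submodule names for tracer.get_fresh_qualname, which suffixes a counter so that a second flex_attention call traced into the same graph no longer reuses (and silently overwrites) the first call's subgraph attributes; this is why the expected backward graph above now references self.fw_graph0, self.joint_graph0, and self.mask_graph0. A minimal sketch of that behavior, assuming a bare torch.fx.Tracer with a manually attached root module as a stand-in for proxy_mode.tracer (the Identity submodules are placeholders for the traced fw/joint/mask graphs):

import torch
import torch.fx

# Sketch only: emulate the proxy tracer's root with a plain nn.Module.
tracer = torch.fx.Tracer()
tracer.root = torch.nn.Module()

for _ in range(2):  # e.g. two flex_attention calls traced into one graph
    subgraph = torch.nn.Identity()  # placeholder for a traced fw/joint/mask graph
    # get_fresh_qualname returns a name not yet present on root:
    # "fw_graph0" on the first call, "fw_graph1" on the second.
    name = tracer.get_fresh_qualname("fw_graph")
    tracer.root.register_module(name, subgraph)

print([name for name, _ in tracer.root.named_children()])
# ['fw_graph0', 'fw_graph1'] -- the second registration no longer clobbers the first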