[FlexAttention] Fix multiple calls to flex bug (#140761)

# Summary
Fixes a long-standing bug we've had in the backward pass for flex attention. See https://github.com/pytorch/pytorch/issues/135161 for details.
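
The root cause, visible in the diff below, is that `trace_flex_attention_backward` registered its `fw_graph`/`joint_graph`/`mask_graph` subgraphs on the tracer root under fixed names, so a later flex attention call's subgraphs could overwrite an earlier call's, leaving the earlier backward pointing at the wrong subgraphs. The fix registers each subgraph under a fresh qualified name (`fw_graph0`, `joint_graph0`, ...) per call. Below is a minimal sketch of the collision and the fresh-name scheme, using a plain `nn.Module` as a stand-in for the tracer root; the `MaskGraph` class and `fresh_name` helper are illustrative, not part of the patch:

```python
import torch

class MaskGraph(torch.nn.Module):  # hypothetical stand-in for a traced mask subgraph
    def forward(self, b, h, q_idx, kv_idx):
        return q_idx >= kv_idx

root = torch.nn.Module()  # stand-in for proxy_mode.tracer.root

# Old behavior: a fixed name, so the second registration silently replaces the
# first, and every node that refers to "mask_graph" resolves to the last subgraph.
root.register_module("mask_graph", MaskGraph())
root.register_module("mask_graph", MaskGraph())  # overwrites the previous subgraph

# New behavior: pick a fresh, unused attribute name per call (roughly what
# tracer.get_fresh_qualname("mask_graph") provides in the patch).
def fresh_name(mod: torch.nn.Module, prefix: str) -> str:
    i = 0
    while hasattr(mod, f"{prefix}{i}"):
        i += 1
    return f"{prefix}{i}"

root.register_module(fresh_name(root, "mask_graph"), MaskGraph())  # "mask_graph0"
root.register_module(fresh_name(root, "mask_graph"), MaskGraph())  # "mask_graph1"
```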

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140761
Approved by: https://github.com/Chillee, https://github.com/zou3519
drisspg 2024-11-15 18:11:09 -08:00 committed by PyTorch MergeBot
parent a173186566
commit 0f9eea1329
2 changed files with 63 additions and 10 deletions


@@ -1527,6 +1527,54 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
        tolerance = Tolerances(atol=2e-1, rtol=2e-1)
        torch.testing.assert_close(out, out2, atol=tolerance.atol, rtol=tolerance.rtol)

    @supported_platform
    def test_multiple_mask_calls(self):
        if TEST_WITH_ROCM:
            self.skipTest(
                "ROCM BUG SEE: https://github.com/pytorch/pytorch/issues/140855"
            )
        # Create inputs
        query = torch.randn(
            (1, 4, 512, 64), dtype=torch.float32, device="cuda", requires_grad=True
        )
        key = torch.randn(
            (1, 4, 512, 64), dtype=torch.float32, device="cuda", requires_grad=True
        )
        value = torch.randn(
            (1, 4, 512, 64), dtype=torch.float32, device="cuda", requires_grad=True
        )

        window_size = 32

        def causal_mask(b, h, q_idx, kv_idx):
            return q_idx >= kv_idx

        def causal_mask_slidewindow_mod(b, h, q_idx, kv_idx):
            return (q_idx >= kv_idx) & (q_idx <= kv_idx + window_size)

        mask1 = create_block_mask(causal_mask, 1, None, 512, 512, _compile=False)
        mask2 = create_block_mask(
            causal_mask_slidewindow_mod, 1, None, 512, 512, _compile=False
        )

        def f(q, k, v):
            out1 = flex_attention(q, k, v, block_mask=mask1)
            out2 = flex_attention(q, k, v, block_mask=mask2)
            return out1 + out2

        f_compiled = torch.compile(f, fullgraph=True)

        out = f(query, key, value)
        out_compiled = f_compiled(query, key, value)

        grads = torch.autograd.grad((out,), (query, key, value), torch.ones_like(out))
        grads_compile = torch.autograd.grad(
            (out_compiled,), (query, key, value), torch.ones_like(out_compiled)
        )

        for grad, grad_compiled in zip(grads, grads_compile):
            torch.testing.assert_close(grad, grad_compiled, atol=3e-2, rtol=3e-2)

    @supported_platform
    def test_multiple_score_mod_calls2(self):
        query = torch.randn((1, 8, 1024, 64), dtype=torch.float32, device="cuda")
@@ -3184,21 +3232,21 @@ class GraphModule(torch.nn.Module):
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f64[2, 2, 128, 4]", primals_2: "f64[2, 2, 128, 4]", primals_3: "f64[2, 2, 128, 4]", full: "i32[1, 1, 1]", full_default: "i32[1, 1, 1, 1]", convert_element_type: "i32[1, 1, 1]", convert_element_type_1: "i32[1, 1, 1, 1]", getitem_2: "f64[2, 2, 128, 4]", getitem_3: "f32[2, 2, 128]", tangents_1: "f64[2, 2, 128, 4]"):
        full_default_4: "f32[2, 2, 128]" = torch.ops.aten.full.default([2, 2, 128], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
        fw_graph = self.fw_graph
        joint_graph = self.joint_graph
        mask_graph = self.mask_graph
        flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, getitem_2, getitem_3, tangents_1, full_default_4, fw_graph, joint_graph, (full, full_default, None, None, convert_element_type, convert_element_type_1, None, None, 1073741824, 1073741824, mask_graph), 0.5, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'OUTPUT_LOGSUMEXP': True}, (), ()); primals_1 = primals_2 = primals_3 = getitem_2 = getitem_3 = tangents_1 = full_default_4 = fw_graph = joint_graph = full = full_default = convert_element_type = convert_element_type_1 = mask_graph = None
        fw_graph0 = self.fw_graph0
        joint_graph0 = self.joint_graph0
        mask_graph0 = self.mask_graph0
        flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, getitem_2, getitem_3, tangents_1, full_default_4, fw_graph0, joint_graph0, (full, full_default, None, None, convert_element_type, convert_element_type_1, None, None, 1073741824, 1073741824, mask_graph0), 0.5, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'OUTPUT_LOGSUMEXP': True}, (), ()); primals_1 = primals_2 = primals_3 = getitem_2 = getitem_3 = tangents_1 = full_default_4 = fw_graph0 = joint_graph0 = full = full_default = convert_element_type = convert_element_type_1 = mask_graph0 = None
        getitem_4: "f64[2, 2, 128, 4]" = flex_attention_backward[0]
        getitem_5: "f64[2, 2, 128, 4]" = flex_attention_backward[1]
        getitem_6: "f64[2, 2, 128, 4]" = flex_attention_backward[2]; flex_attention_backward = None
        return (getitem_4, getitem_5, getitem_6)

    class fw_graph(torch.nn.Module):
    class fw_graph0(torch.nn.Module):
        def forward(self, arg0_1: "f64[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i32[]", arg4_1: "i32[]"):
            mul: "f64[]" = torch.ops.aten.mul.Tensor(arg0_1, arg0_1); arg0_1 = None
            return mul

    class joint_graph(torch.nn.Module):
    class joint_graph0(torch.nn.Module):
        def forward(self, arg0_1: "f64[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i32[]", arg4_1: "i32[]", arg5_1: "f64[]"):
            mul: "f64[]" = torch.ops.aten.mul.Tensor(arg0_1, arg0_1); mul = None
            mul_1: "f64[]" = torch.ops.aten.mul.Tensor(arg5_1, arg0_1)
@@ -3206,7 +3254,7 @@ class GraphModule(torch.nn.Module):
            add: "f64[]" = torch.ops.aten.add.Tensor(mul_2, mul_1); mul_2 = mul_1 = None
            return [add, None, None, None, None]

    class mask_graph(torch.nn.Module):
    class mask_graph0(torch.nn.Module):
        def forward(self, arg0_1: "i32[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i32[]"):
            full: "b8[]" = torch.ops.aten.full.default([], True, dtype = torch.bool, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
            return full


@@ -946,9 +946,14 @@ def trace_flex_attention_backward(
    )
    assert isinstance(proxy_mode.tracer, torch.fx.Tracer)
    block_mask = block_mask[:-1] + (mask_graph,)
    proxy_mode.tracer.root.register_module("fw_graph", fw_graph) # type: ignore[arg-type]
    proxy_mode.tracer.root.register_module("joint_graph", joint_graph)
    proxy_mode.tracer.root.register_module("mask_graph", mask_graph)
    qualname = proxy_mode.tracer.get_fresh_qualname("fw_graph")
    proxy_mode.tracer.root.register_module(qualname, fw_graph) # type: ignore[arg-type]
    qualname = proxy_mode.tracer.get_fresh_qualname("joint_graph")
    proxy_mode.tracer.root.register_module(qualname, joint_graph)
    qualname = proxy_mode.tracer.get_fresh_qualname("mask_graph")
    proxy_mode.tracer.root.register_module(qualname, mask_graph)
    node_args = (
        query,
        key,