[cp] override compute_log_sumexp to True for aten._scaled_dot_product_efficient_attention.default if False (#145421)

## Description
Our current CP doesn't support efficient attention when `compute_log_sumexp=False`. `compute_log_sumexp` is `False` only when `requires_grad=False`, and since PP's [shape inference](d95a6babcc/torch/distributed/pipelining/stage.py (L1387)) runs under the `torch.no_grad()` context, we need to override `compute_log_sumexp` to `True` in our CP attention implementation.
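To make the failure mode concrete, here is a minimal sketch (assumption: `wants_logsumexp` is a made-up stand-in for the eager heuristic, not the actual dispatcher code) of why the flag ends up `False` during PP shape inference and what the override amounts to:

```python
import torch

def wants_logsumexp(q, k, v):
    # Simplified stand-in for the eager heuristic: the logsumexp is only
    # needed when a backward pass is possible, i.e. some input requires grad
    # and grad mode is currently enabled.
    requires_grad = q.requires_grad or k.requires_grad or v.requires_grad
    return requires_grad and torch.is_grad_enabled()

q = k = v = torch.randn(1, 2, 8, 16, requires_grad=True)
print(wants_logsumexp(q, k, v))      # True during a normal training step

with torch.no_grad():                # the mode PP shape inference runs in
    print(wants_logsumexp(q, k, v))  # False -> the old CP code raised NotImplementedError

# What this PR does inside the CP wrapper, regardless of the incoming flag:
compute_log_sumexp = True  # ring attention always needs the LSE to merge shard outputs
```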

## Test
- Test PP+FSDP+CP w/ `mixed_precision = "float32"` in torchtitan

- `pytest test/distributed/tensor/test_attention.py -s -k test_ring_attention_sdpa` (see the forward-only sketch below)
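A rough sketch of the forward-only situation the new `test_forward_only` subtest covers (shapes, mesh setup, and the torchrun launch are illustrative assumptions, not part of this PR):

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.experimental import context_parallel
from torch.nn.attention import SDPBackend, sdpa_kernel

# Assumes launch via torchrun so rank/world-size env vars are set.
dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
mesh = init_device_mesh("cuda", (dist.get_world_size(),))

q = torch.randn(2, 8, 256, 64, device="cuda")  # (batch, heads, seq, head_dim)
k, v = torch.randn_like(q), torch.randn_like(q)

with torch.no_grad():  # same mode PP shape inference uses
    with context_parallel(mesh, buffers=[q, k, v], buffer_seq_dims=[2, 2, 2]):
        with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
            # Before this PR the CP efficient-attention path raised
            # "compute_log_sumexp must be set" here; it now runs.
            out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```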

Before:
<img width="1880" alt="image" src="https://github.com/user-attachments/assets/872ff583-295e-4751-a280-cf7f2d41c61a" />

After:
<img width="2988" alt="image" src="https://github.com/user-attachments/assets/4bdcc2e5-22a5-427a-91a5-82206d5bd78f" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145421
Approved by: https://github.com/H-Huang, https://github.com/tianyu-l
Xilun Wu 2025-01-23 14:00:28 -08:00 committed by PyTorch MergeBot
parent 53fc921ce2
commit 2ce70da96c
2 changed files with 41 additions and 25 deletions


@@ -73,6 +73,7 @@ class RingAttentionTest(DTensorTestBase):
                 "backend": backends,
                 "load_balance": [True, False],
                 "rotater": [_RotateMethod.ALL_TO_ALL, _RotateMethod.ALL_GATHER],
+                "test_forward_only": [True, False],
             },
             self._test_ring_attention_sdpa,
         )
@@ -84,7 +85,17 @@ class RingAttentionTest(DTensorTestBase):
         backend: SDPBackend,
         load_balance: bool,
         rotater: _RotateMethod,
+        test_forward_only: bool,
     ) -> None:
+        def fn_eval(fn, *args, **kwargs):
+            if test_forward_only:
+                with torch.no_grad():
+                    return fn(*args, **kwargs)
+            else:
+                out = fn(*args, **kwargs)
+                out.sum().backward()
+                return out
         if load_balance and not is_causal:
             return
@@ -130,8 +141,7 @@ class RingAttentionTest(DTensorTestBase):
         dist.broadcast(v, src=0)
         with sdpa_kernel(backend):
-            out = F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
-            out.sum().backward()
+            out = fn_eval(F.scaled_dot_product_attention, q, k, v, is_causal=is_causal)
         cp_q = q.detach().clone()
         cp_k = k.detach().clone()
@@ -158,26 +168,23 @@ class RingAttentionTest(DTensorTestBase):
                 else:
                     fn = F.scaled_dot_product_attention
-                cp_out = fn(cp_q, cp_k, cp_v, is_causal=is_causal)
-                cp_out.sum().backward()
+                cp_out = fn_eval(fn, cp_q, cp_k, cp_v, is_causal=is_causal)
                 if not compiled and rotater == _RotateMethod.ALL_TO_ALL:
                     # Compiler and CommDebugMode do not work well together.
+                    expect_all2all_count = (
+                        self.world_size - 1
+                        if test_forward_only
+                        else self.world_size * 3 - 2
+                    )
                     self.assertDictEqual(
                         comm_mode.get_comm_counts(),
-                        {
-                            c10d_functional.all_to_all_single: self.world_size * 3
-                            - 2
-                        },
+                        {c10d_functional.all_to_all_single: expect_all2all_count},
                     )
         # Due to numerical error, we need to choose different atol for different
         # attention kernels
-        cp_out, cp_dq, cp_dk, cp_dv = context_parallel_unshard(
-            device_mesh,
-            [cp_out, cp_q.grad, cp_k.grad, cp_v.grad],
-            [2, 2, 2, 2],
-        )
+        (cp_out,) = context_parallel_unshard(device_mesh, [cp_out], [2])
         atol = (
             1e-08
             if backend == SDPBackend.EFFICIENT_ATTENTION
@@ -185,18 +192,25 @@ class RingAttentionTest(DTensorTestBase):
         )
         self.assertTrue(torch.allclose(out, cp_out, atol=atol))
-        atol = (
-            2e-06
-            if backend == SDPBackend.EFFICIENT_ATTENTION
-            else 8e-3 * self.world_size
-        )
-        self.assertTrue(torch.allclose(q.grad, cp_dq, atol=atol))
-        self.assertTrue(torch.allclose(k.grad, cp_dk, atol=atol))
-        self.assertTrue(torch.allclose(v.grad, cp_dv, atol=atol))
+        if not test_forward_only:
+            cp_dq, cp_dk, cp_dv = context_parallel_unshard(
+                device_mesh,
+                [cp_q.grad, cp_k.grad, cp_v.grad],
+                [2, 2, 2],
+            )
+            atol = (
+                2e-06
+                if backend == SDPBackend.EFFICIENT_ATTENTION
+                else 8e-3 * self.world_size
+            )
+            self.assertTrue(torch.allclose(q.grad, cp_dq, atol=atol))
+            self.assertTrue(torch.allclose(k.grad, cp_dk, atol=atol))
+            self.assertTrue(torch.allclose(v.grad, cp_dv, atol=atol))
+            cp_q.grad = None
+            cp_k.grad = None
+            cp_v.grad = None
-        cp_q.grad = None
-        cp_k.grad = None
-        cp_v.grad = None
         cp_q.requires_grad = False
         cp_k.requires_grad = False
         cp_v.requires_grad = False
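As a side note on the new `expect_all2all_count` above: with the ALL_TO_ALL rotate method, the forward pass rotates the K/V shards once around the ring (`world_size - 1` exchanges), while the full forward+backward run asserts `world_size * 3 - 2`. My reading is that backward re-circulates K/V plus the K/V gradients, but the numbers themselves come straight from the test. Evaluating the two expressions for a few world sizes:

```python
# The all-to-all counts the test asserts, per world size.
for world_size in (2, 4, 8):
    forward_only = world_size - 1          # forward-only subtest
    forward_backward = world_size * 3 - 2  # forward + backward
    print(world_size, forward_only, forward_backward)  # -> (2, 1, 4), (4, 3, 10), (8, 7, 22)
```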


@@ -225,8 +225,10 @@ def _scaled_dot_product_ring_efficient_attention(
 ) -> tuple[torch.Tensor, ...]:
     if attn_bias is not None:
         raise NotImplementedError("attn_bias is not supported yet")
-    if not compute_log_sumexp:
-        raise NotImplementedError("compute_log_sumexp must be set")
+    # CP requires compute_log_sumexp to be True because it always merges LSE
+    compute_log_sumexp = True
     seq_dim = 2
     return _templated_ring_attention(
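The one-line comment in the diff is the core of the change: ring/context-parallel attention computes attention against each K/V shard separately, and those partial outputs can only be combined if every shard also reports its log-sum-exp. A small hedged numerical sketch of that merge rule (plain softmax attention with no scaling or masking; names and shapes are illustrative):

```python
import torch

def partial_attention(q, k, v):
    # Unscaled, unmasked softmax attention over one K/V shard, returning the
    # shard output together with its log-sum-exp over the key dimension.
    scores = q @ k.transpose(-1, -2)         # [S_q, S_kv]
    lse = torch.logsumexp(scores, dim=-1)    # [S_q]
    out = torch.softmax(scores, dim=-1) @ v  # [S_q, D]
    return out, lse

q = torch.randn(4, 8)
k, v = torch.randn(16, 8), torch.randn(16, 8)

ref, _ = partial_attention(q, k, v)  # attention over all keys at once

# Attention over two K/V shards, merged via the shards' LSEs.
out1, lse1 = partial_attention(q, k[:8], v[:8])
out2, lse2 = partial_attention(q, k[8:], v[8:])
lse = torch.logaddexp(lse1, lse2)
merged = (
    torch.exp(lse1 - lse).unsqueeze(-1) * out1
    + torch.exp(lse2 - lse).unsqueeze(-1) * out2
)

print(torch.allclose(ref, merged, atol=1e-6))  # True: the merge needs the per-shard LSEs
```

Forcing `compute_log_sumexp=True` simply guarantees the efficient-attention kernel produces that LSE even when autograd would not otherwise ask for it.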