[cp] override compute_log_sumexp to True for aten._scaled_dot_product_efficient_attention.default if False (#145421)
## Description
Our current CP implementation doesn't support efficient attention when `compute_log_sumexp=False`. `compute_log_sumexp` is `False` only when the inputs have `requires_grad=False`, and since PP's [shape inference](d95a6babcc/torch/distributed/pipelining/stage.py (L1387)) runs under a `torch.no_grad()` context, we need to override `compute_log_sumexp` to `True` in our CP attention implementation.
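For context, a minimal sketch of why the flag ends up `False` (the real decision lives in PyTorch's C++ SDPA dispatch; the helper below is illustrative, not a PyTorch API):

```python
import torch

def needs_log_sumexp(q, k, v) -> bool:
    # Eager SDPA only asks the efficient-attention kernel for the
    # log-sum-exp tensor when a backward pass is possible: grad mode
    # is enabled and at least one input requires grad.
    return torch.is_grad_enabled() and any(
        t.requires_grad for t in (q, k, v)
    )

q = torch.randn(2, 8, 16, 64, requires_grad=True)
assert needs_log_sumexp(q, q, q)
with torch.no_grad():  # the situation PP shape inference creates
    assert not needs_log_sumexp(q, q, q)
```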
## Test
- Test PP+FSDP+CP w/ `mixed_precision = "float32"` in torchtitan
- `pytest test/distributed/tensor/test_attention.py -s -k test_ring_attention_sdpa`
Before:
<img width="1880" alt="image" src="https://github.com/user-attachments/assets/872ff583-295e-4751-a280-cf7f2d41c61a" />
After:
<img width="2988" alt="image" src="https://github.com/user-attachments/assets/4bdcc2e5-22a5-427a-91a5-82206d5bd78f" />
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145421
Approved by: https://github.com/H-Huang, https://github.com/tianyu-l
parent 53fc921ce2
commit 2ce70da96c

2 changed files with 41 additions and 25 deletions
In `test/distributed/tensor/test_attention.py`, the subtest grid gains a `test_forward_only` dimension:

```diff
@@ -73,6 +73,7 @@ class RingAttentionTest(DTensorTestBase):
                 "backend": backends,
                 "load_balance": [True, False],
                 "rotater": [_RotateMethod.ALL_TO_ALL, _RotateMethod.ALL_GATHER],
+                "test_forward_only": [True, False],
             },
             self._test_ring_attention_sdpa,
         )
```
```diff
@@ -84,7 +85,17 @@ class RingAttentionTest(DTensorTestBase):
         backend: SDPBackend,
         load_balance: bool,
         rotater: _RotateMethod,
+        test_forward_only: bool,
     ) -> None:
+        def fn_eval(fn, *args, **kwargs):
+            if test_forward_only:
+                with torch.no_grad():
+                    return fn(*args, **kwargs)
+            else:
+                out = fn(*args, **kwargs)
+                out.sum().backward()
+                return out
+
         if load_balance and not is_causal:
             return
```
```diff
@@ -130,8 +141,7 @@ class RingAttentionTest(DTensorTestBase):
         dist.broadcast(v, src=0)
 
         with sdpa_kernel(backend):
-            out = F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
-            out.sum().backward()
+            out = fn_eval(F.scaled_dot_product_attention, q, k, v, is_causal=is_causal)
 
         cp_q = q.detach().clone()
         cp_k = k.detach().clone()
```
```diff
@@ -158,26 +168,23 @@ class RingAttentionTest(DTensorTestBase):
         else:
             fn = F.scaled_dot_product_attention
 
-        cp_out = fn(cp_q, cp_k, cp_v, is_causal=is_causal)
-        cp_out.sum().backward()
+        cp_out = fn_eval(fn, cp_q, cp_k, cp_v, is_causal=is_causal)
 
         if not compiled and rotater == _RotateMethod.ALL_TO_ALL:
             # Compiler and CommDebugMode do not work well together.
+            expect_all2all_count = (
+                self.world_size - 1
+                if test_forward_only
+                else self.world_size * 3 - 2
+            )
             self.assertDictEqual(
                 comm_mode.get_comm_counts(),
-                {
-                    c10d_functional.all_to_all_single: self.world_size * 3
-                    - 2
-                },
+                {c10d_functional.all_to_all_single: expect_all2all_count},
             )
 
         # Due to numerical error, we need to choose different atol for different
         # attention kernels
-        cp_out, cp_dq, cp_dk, cp_dv = context_parallel_unshard(
-            device_mesh,
-            [cp_out, cp_q.grad, cp_k.grad, cp_v.grad],
-            [2, 2, 2, 2],
-        )
+        (cp_out,) = context_parallel_unshard(device_mesh, [cp_out], [2])
         atol = (
             1e-08
             if backend == SDPBackend.EFFICIENT_ATTENTION
```
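As a quick sanity check on the counts asserted above, a worked example with an illustrative world size (the formulas come straight from the test; the world size of 4 is arbitrary):

```python
world_size = 4
forward_only = world_size - 1               # 3 all_to_all_single calls
forward_and_backward = world_size * 3 - 2   # 10 with the backward pass
assert forward_only == 3 and forward_and_backward == 10
```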
```diff
@@ -185,18 +192,25 @@ class RingAttentionTest(DTensorTestBase):
         )
         self.assertTrue(torch.allclose(out, cp_out, atol=atol))
 
-        atol = (
-            2e-06
-            if backend == SDPBackend.EFFICIENT_ATTENTION
-            else 8e-3 * self.world_size
-        )
-        self.assertTrue(torch.allclose(q.grad, cp_dq, atol=atol))
-        self.assertTrue(torch.allclose(k.grad, cp_dk, atol=atol))
-        self.assertTrue(torch.allclose(v.grad, cp_dv, atol=atol))
+        if not test_forward_only:
+            cp_dq, cp_dk, cp_dv = context_parallel_unshard(
+                device_mesh,
+                [cp_q.grad, cp_k.grad, cp_v.grad],
+                [2, 2, 2],
+            )
+            atol = (
+                2e-06
+                if backend == SDPBackend.EFFICIENT_ATTENTION
+                else 8e-3 * self.world_size
+            )
+            self.assertTrue(torch.allclose(q.grad, cp_dq, atol=atol))
+            self.assertTrue(torch.allclose(k.grad, cp_dk, atol=atol))
+            self.assertTrue(torch.allclose(v.grad, cp_dv, atol=atol))
+
+            cp_q.grad = None
+            cp_k.grad = None
+            cp_v.grad = None
 
-        cp_q.grad = None
-        cp_k.grad = None
-        cp_v.grad = None
         cp_q.requires_grad = False
         cp_k.requires_grad = False
         cp_v.requires_grad = False
```
And in the CP attention implementation, `_scaled_dot_product_ring_efficient_attention` now overrides the flag instead of rejecting it:

```diff
@@ -225,8 +225,10 @@ def _scaled_dot_product_ring_efficient_attention(
 ) -> tuple[torch.Tensor, ...]:
     if attn_bias is not None:
         raise NotImplementedError("attn_bias is not supported yet")
 
-    if not compute_log_sumexp:
-        raise NotImplementedError("compute_log_sumexp must be set")
+    # CP requires compute_log_sumexp to be True because it always merges LSE
+    compute_log_sumexp = True
 
     seq_dim = 2
     return _templated_ring_attention(
```
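The comment above is the crux of the fix: ring attention combines the partial outputs from each step using their log-sum-exp statistics, so the LSE must always be computed. A minimal sketch of such a merge (assumed shapes and helper name; the actual logic lives in `_templated_ring_attention`):

```python
import torch

def merge_partials(out_a, lse_a, out_b, lse_b):
    # out_*: [B, H, S, D] partial attention outputs from two ring steps;
    # lse_*: [B, H, S] their per-row log-sum-exp values.
    lse = torch.logaddexp(lse_a, lse_b)         # combined normalizer
    w_a = torch.exp(lse_a - lse).unsqueeze(-1)  # rescale each partial
    w_b = torch.exp(lse_b - lse).unsqueeze(-1)
    return out_a * w_a + out_b * w_b, lse
```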