pytorch/torch/distributed/tensor/experimental
Xilun Wu 2ce70da96c [cp] override compute_log_sumexp to True for aten._scaled_dot_product_efficient_attention.default if False (#145421)
## Description
Our current CP implementation doesn't support efficient attention when `compute_log_sumexp=False`. `compute_log_sumexp` is `False` only if `requires_grad=False`, and since PP's [shape inference](d95a6babcc/torch/distributed/pipelining/stage.py (L1387)) runs under a `torch.no_grad()` context, we need to override `compute_log_sumexp` to `True` in our CP attention implementation.
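
For illustration only, the override can be sketched as a small argument-rewriting step applied before the op is dispatched. This is not the code in `_attention.py`: `force_compute_log_sumexp` is a hypothetical helper, and the positional index of `compute_log_sumexp` is an assumption based on the op being called as `(query, key, value, attn_bias, compute_log_sumexp, ...)`.

```python
from typing import Any, Dict, Tuple

# Assumption: compute_log_sumexp is the 5th positional argument of
# aten._scaled_dot_product_efficient_attention.default, i.e.
# (query, key, value, attn_bias, compute_log_sumexp, ...).
COMPUTE_LOG_SUMEXP_INDEX = 4


def force_compute_log_sumexp(
    args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """Hypothetical helper: return copies of (args, kwargs) with
    compute_log_sumexp forced to True, however the caller passed it."""
    new_args = tuple(args)
    new_kwargs = dict(kwargs)
    if "compute_log_sumexp" in new_kwargs:
        # Passed by keyword: overwrite the value.
        new_kwargs["compute_log_sumexp"] = True
    elif len(new_args) > COMPUTE_LOG_SUMEXP_INDEX:
        # Passed positionally: rebuild the tuple with True in that slot.
        new_args = (
            new_args[:COMPUTE_LOG_SUMEXP_INDEX]
            + (True,)
            + new_args[COMPUTE_LOG_SUMEXP_INDEX + 1:]
        )
    else:
        # Not passed at all: supply it by keyword.
        new_kwargs["compute_log_sumexp"] = True
    return new_args, new_kwargs
```

A wrapper around the efficient-attention call could run its arguments through such a helper, so that the log-sum-exp output is produced even when the call happens under `torch.no_grad()`.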

## Test
- Test PP+FSDP+CP w/ `mixed_precision = "float32"` in torchtitan

- `pytest test/distributed/tensor/test_attention.py -s -k test_ring_attention_sdpa`

Before:
<img width="1880" alt="image" src="https://github.com/user-attachments/assets/872ff583-295e-4751-a280-cf7f2d41c61a" />

After:
<img width="2988" alt="image" src="https://github.com/user-attachments/assets/4bdcc2e5-22a5-427a-91a5-82206d5bd78f" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145421
Approved by: https://github.com/H-Huang, https://github.com/tianyu-l
2025-01-24 06:17:54 +00:00
__init__.py
_attention.py
_func_map.py
_register_sharding.py
_tp_transform.py