pytorch/test/distributed
Xilun Wu 2ce70da96c [cp] override compute_log_sumexp to True for aten._scaled_dot_product_efficient_attention.default if False (#145421)
## Description
Our current CP implementation doesn't support efficient attention when `compute_log_sumexp=False`. `compute_log_sumexp` is `False` only when `requires_grad=False`, and since PP's [shape inference](d95a6babcc/torch/distributed/pipelining/stage.py (L1387)) runs under the `torch.no_grad()` context, we need to override `compute_log_sumexp` to `True` in our CP attention implementation.
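
The override described above can be sketched roughly as follows. This is an illustrative stand-in, not PyTorch's actual CP dispatch code: `cp_efficient_attention_handler` and `dispatch` are hypothetical names, standing in for the CP interception of `aten._scaled_dot_product_efficient_attention.default`.

```python
# Illustrative sketch of the fix: when the CP handler intercepts the
# efficient-attention op with compute_log_sumexp=False (as happens under
# torch.no_grad(), e.g. during PP shape inference), it forces the flag to
# True before dispatching, since CP ring attention needs the logsumexp
# tensor to combine partial results across ranks.
# All names here are hypothetical, not PyTorch internals.

def cp_efficient_attention_handler(args, kwargs, dispatch):
    """Wrap an efficient-attention call, overriding compute_log_sumexp."""
    kwargs = dict(kwargs)  # avoid mutating the caller's kwargs
    if not kwargs.get("compute_log_sumexp", False):
        kwargs["compute_log_sumexp"] = True  # CP always needs logsumexp
    return dispatch(args, kwargs)

def fake_dispatch(args, kwargs):
    # Toy stand-in for the real aten op; just reports the flag it received.
    return kwargs["compute_log_sumexp"]
```

For example, a call that arrives with `compute_log_sumexp=False` is dispatched with the flag flipped to `True`, while a call that already passes `True` is forwarded unchanged.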

## Test
- Test PP+FSDP+CP w/ `mixed_precision = "float32"` in torchtitan

- `pytest test/distributed/tensor/test_attention.py -s -k test_ring_attention_sdpa`

Before:
<img width="1880" alt="image" src="https://github.com/user-attachments/assets/872ff583-295e-4751-a280-cf7f2d41c61a" />

After:
<img width="2988" alt="image" src="https://github.com/user-attachments/assets/4bdcc2e5-22a5-427a-91a5-82206d5bd78f" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145421
Approved by: https://github.com/H-Huang, https://github.com/tianyu-l
2025-01-24 06:17:54 +00:00
| Name | Latest commit | Date |
|------|---------------|------|
| _composable | PEP585 update - test (#145176) | 2025-01-22 04:48:28 +00:00 |
| _shard | PEP585 update - test (#145176) | 2025-01-22 04:48:28 +00:00 |
| _tools | PEP585 update - test (#145176) | 2025-01-22 04:48:28 +00:00 |
| algorithms | | |
| bin | | |
| checkpoint | PEP585 update - test (#145176) | 2025-01-22 04:48:28 +00:00 |
| elastic | PEP585 update - test (#145176) | 2025-01-22 04:48:28 +00:00 |
| flight_recorder | Revert "Use absolute path path.resolve() -> path.absolute() (#129409)" | 2025-01-04 14:17:20 +00:00 |
| fsdp | PEP585 update - test (#145176) | 2025-01-22 04:48:28 +00:00 |
| launcher | PEP585 update - test (#145176) | 2025-01-22 04:48:28 +00:00 |
| nn/jit | PEP585 update - test (#145176) | 2025-01-22 04:48:28 +00:00 |
| optim | PEP585 update - test (#145176) | 2025-01-22 04:48:28 +00:00 |
| pipelining | PEP585 update - test (#145176) | 2025-01-22 04:48:28 +00:00 |
| rpc | | |
| tensor | [cp] override compute_log_sumexp to True for aten._scaled_dot_product_efficient_attention.default if False (#145421) | 2025-01-24 06:17:54 +00:00 |
| argparse_util_test.py | | |
| test_backends.py | | |
| test_c10d_common.py | PEP585 update - test (#145176) | 2025-01-22 04:48:28 +00:00 |
| test_c10d_functional_native.py | PEP585 update - test (#145176) | 2025-01-22 04:48:28 +00:00 |
| test_c10d_gloo.py | [ROCm] Enable post-merge trunk workflow on MI300 runners; skip and fix MI300 related failed tests (#143673) | 2025-01-09 05:18:57 +00:00 |
| test_c10d_logger.py | | |
| test_c10d_nccl.py | PEP585 update - test (#145176) | 2025-01-22 04:48:28 +00:00 |
| test_c10d_object_collectives.py | | |
| test_c10d_ops_nccl.py | | |
| test_c10d_pypg.py | | |
| test_c10d_spawn.py | | |
| test_c10d_spawn_gloo.py | PEP585 update - test (#145176) | 2025-01-22 04:48:28 +00:00 |
| test_c10d_spawn_nccl.py | PEP585 update - test (#145176) | 2025-01-22 04:48:28 +00:00 |
| test_c10d_spawn_ucc.py | PEP585 update - test (#145176) | 2025-01-22 04:48:28 +00:00 |
| test_c10d_ucc.py | XFAIL test_save_load_checkpoint (#144927) | 2025-01-16 07:31:56 +00:00 |
| test_collective_utils.py | | |
| test_composability.py | composability test cleanup (#145011) | 2025-01-18 04:37:12 +00:00 |
| test_compute_comm_reordering.py | | |
| test_control_collectives.py | | |
| test_data_parallel.py | | |
| test_device_mesh.py | | |
| test_distributed_spawn.py | | |
| test_dynamo_distributed.py | PEP585 update - test (#145176) | 2025-01-22 04:48:28 +00:00 |
| test_fake_pg.py | | |
| test_functional_api.py | | |
| test_inductor_collectives.py | | |
| test_launcher.py | | |
| test_multi_threaded_pg.py | | |
| test_nccl.py | | |
| test_pg_wrapper.py | | |
| test_store.py | [BE][CI] bump ruff to 0.8.4 (#143753) | 2024-12-24 12:24:10 +00:00 |
| test_symmetric_memory.py | ROCm: Enable 4 gpu tests for distributed config (#140319) | 2025-01-02 17:22:11 +00:00 |