pytorch/torch/testing
Will Feng 9e06572704 [Traceable FSDP2][Inductor] Create grouped nodes for FSDP2 all-gather code block and reduce-scatter code block (after Buffer/Operation split) (#131510)
This PR creates these `GroupedSchedulerNode`s:
- One for each all-gather code block (cast + copy-in + all-gather)
- One for each all-gather-wait code block (all-gather-wait + copy-out)
- One for each reduce-scatter code block (copy-in + reduce-scatter)
- One for each reduce-scatter-wait code block (reduce-scatter-wait)

This serves two goals:
- Prevent outside ops from being fused into these op groups, in order to have more predicable memory usage.
- Make it easier to specify the dependency e.g. from `i+1` all-gather group node to the `i` all-gather-wait group node, to enforce FSDP2 comm ordering (i.e. "serialization of comms").

The actual "reorder-for-FSDP-compute-comm-overlap" PR will come next.

Test commands:
- `pytest -rA  test/distributed/test_compute_comm_reordering.py::TestComputeCommReorderingMultiProc`
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_transformer_backend_inductor`
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_nested_fully_shard_backend_inductor`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131510
Approved by: https://github.com/yifuwang
2024-07-27 08:39:58 +00:00
..
_internal [Traceable FSDP2][Inductor] Create grouped nodes for FSDP2 all-gather code block and reduce-scatter code block (after Buffer/Operation split) (#131510) 2024-07-27 08:39:58 +00:00
__init__.py
_comparison.py
_creation.py
_utils.py