Mirror of https://github.com/saymrwulf/pytorch.git, synced 2026-05-14 20:57:59 +00:00
This PR creates these `GroupedSchedulerNode`s:
- One for each all-gather code block (cast + copy-in + all-gather)
- One for each all-gather-wait code block (all-gather-wait + copy-out)
- One for each reduce-scatter code block (copy-in + reduce-scatter)
- One for each reduce-scatter-wait code block (reduce-scatter-wait)

This serves two goals:
- Prevent outside ops from being fused into these op groups, in order to have more predictable memory usage.
- Make it easier to specify dependencies, e.g. from the `i+1` all-gather group node to the `i` all-gather-wait group node, to enforce FSDP2 comm ordering (i.e. "serialization of comms"). The actual "reorder-for-FSDP-compute-comm-overlap" PR will come next.

Test commands:
- `pytest -rA test/distributed/test_compute_comm_reordering.py::TestComputeCommReorderingMultiProc`
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_transformer_backend_inductor`
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_nested_fully_shard_backend_inductor`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131510
Approved by: https://github.com/yifuwang
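The grouping and serialization idea above can be sketched with a toy dependency graph. This is a minimal illustration, not Inductor's real scheduler API: `Node`, `GroupedNode`, and `build_fsdp_groups` are hypothetical names, and only the all-gather side is modeled. It shows the two goals concretely: member ops are wrapped in a single group node (so a fusion pass treating groups atomically cannot pull outside ops in), and an explicit edge makes all-gather `i+1` depend on all-gather-wait `i`.

```python
# Toy sketch of grouped scheduler nodes (hypothetical classes, not
# Inductor's actual GroupedSchedulerNode implementation).
from dataclasses import dataclass, field

@dataclass
class Node:
    name: str

@dataclass
class GroupedNode:
    """Wraps several ops into one scheduling unit; a fusion pass that
    treats groups atomically cannot fuse outside ops into the members."""
    name: str
    members: list
    deps: set = field(default_factory=set)  # names of group nodes this one must follow

def build_fsdp_groups(num_layers):
    """Build all-gather / all-gather-wait group nodes for each layer and
    add the 'serialization of comms' edges between consecutive layers."""
    groups = []
    for i in range(num_layers):
        # One group per all-gather code block: cast + copy-in + all-gather.
        ag = GroupedNode(
            f"all_gather_{i}",
            [Node(f"cast_{i}"), Node(f"copy_in_{i}"), Node(f"ag_{i}")],
        )
        # One group per all-gather-wait code block: wait + copy-out.
        ag_wait = GroupedNode(
            f"all_gather_wait_{i}",
            [Node(f"ag_wait_{i}"), Node(f"copy_out_{i}")],
            deps={ag.name},  # wait group naturally follows its all-gather
        )
        groups += [ag, ag_wait]
    # Enforce comm ordering: all-gather i+1 must wait for all-gather-wait i.
    for i in range(num_layers - 1):
        next_ag = groups[2 * (i + 1)]
        prev_wait = groups[2 * i + 1]
        next_ag.deps.add(prev_wait.name)
    return groups

if __name__ == "__main__":
    for g in build_fsdp_groups(3):
        print(g.name, "->", sorted(g.deps))
```

With three layers, `all_gather_1` ends up depending on `all_gather_wait_0`, which is exactly the `i+1`-to-`i` edge the PR description calls out.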