pytorch/torch/distributed
Andrew Gu aa61e251d4 [FSDP2] Added shard_placement_fn arg (#137496)
## Overview
This PR adds a `shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]]]` arg to `fully_shard` that allows users to specify FSDP sharding on a nonzero tensor dim. If sharding on a nonzero dim, that dim's size must be divisible by the FSDP shard world size.

```python
# Example: shard each parameter on its largest tensor dim
# (import paths as of this PR, where FSDP2 still lives under torch.distributed._composable)
from typing import Optional

import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.tensor import Shard

def shard_placement_fn(param: nn.Parameter) -> Optional[Shard]:
    largest_dim = largest_dim_size = -1
    for dim, dim_size in enumerate(param.shape):
        if dim_size > largest_dim_size:
            largest_dim = dim
            largest_dim_size = dim_size
    return Shard(largest_dim)

fully_shard(module, shard_placement_fn=shard_placement_fn)
```
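
Since sharding on a nonzero dim requires that dim's size be divisible by the FSDP shard world size, a placement fn may want to fall back to the default dim-0 sharding when the largest dim does not divide evenly. Below is a minimal sketch of that variant; the `make_shard_placement_fn` factory and the closure over `shard_world_size` are illustrative rather than part of this PR, and returning `None` is assumed to keep the default `Shard(0)` placement.

```python
from typing import Optional

import torch.nn as nn
from torch.distributed.tensor import Shard


def make_shard_placement_fn(shard_world_size: int):
    def shard_placement_fn(param: nn.Parameter) -> Optional[Shard]:
        if param.ndim == 0:
            return None  # scalars cannot be sharded on a nonzero dim
        largest_dim = max(range(param.ndim), key=lambda d: param.shape[d])
        if param.shape[largest_dim] % shard_world_size == 0:
            return Shard(largest_dim)
        # Returning None is assumed to fall back to the default dim-0 sharding.
        return None

    return shard_placement_fn
```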

## Follow-Ups
- **Copy kernels:** For all-gather copy-out, we currently copy out to temporaries and then chunk along dim 0 and cat along the shard dim, incurring an extra copy for parameters sharded on a nonzero tensor dim (see the toy sketch below). Similarly, for reduce-scatter copy-in, we currently chunk along the shard dim and cat along dim 0, incurring an extra copy for gradients sharded on a nonzero tensor dim. @yifuwang has ideas for adding additional split-size args to the copy ops that would allow fusing these extra copies into the existing all-gather copy-out and reduce-scatter copy-in.
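
As a toy illustration with plain tensors (not the actual FSDP2 copy kernels), reassembling a parameter sharded on dim 1 from the all-gather temporary requires exactly the chunk-dim-0 -> cat-shard-dim step described above:

```python
import torch

world_size = 4
full_param = torch.arange(8 * 8, dtype=torch.float32).reshape(8, 8)
shard_dim = 1  # parameter is sharded on a nonzero dim

# Each rank holds one contiguous shard along `shard_dim`.
shards = [s.contiguous() for s in full_param.chunk(world_size, dim=shard_dim)]

# All-gather concatenates the per-rank shards along dim 0 into a temporary.
gathered = torch.cat(shards, dim=0)  # shape (32, 2)

# Reassembling the full parameter needs chunk-dim-0 -> cat-shard-dim,
# the extra copy this follow-up wants to fuse into the copy-out.
reassembled = torch.cat(gathered.chunk(world_size, dim=0), dim=shard_dim)
assert torch.equal(reassembled, full_param)
```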

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137496
Approved by: https://github.com/weifengpy
ghstack dependencies: #137593
2024-10-09 19:13:32 +00:00
_composable [FSDP2] Added shard_placement_fn arg (#137496) 2024-10-09 19:13:32 +00:00
_shard [BE]: Update mypy to 1.11.2 (#133816) 2024-09-16 19:44:11 +00:00
_sharded_tensor
_sharding_spec
_symmetric_memory [async-tp] fix a race condition that can cause silent correctness issue (#137199) 2024-10-03 10:42:37 +00:00
_tensor [reland][dtensor] move DTensor to public namespace (#134203) 2024-09-08 17:08:40 +00:00
_tools
algorithms [BE]: Update mypy to 1.11.2 (#133816) 2024-09-16 19:44:11 +00:00
autograd
benchmarks
checkpoint [Distributed] fix FileSystemWriter __init__ (#136135) 2024-09-16 19:11:08 +00:00
elastic Fix rendezvous error due to EtcdStore get method not waiting in some cases (#137056) 2024-10-02 01:45:00 +00:00
examples
fsdp Update real device in FSDP state_dict_utils (#134994) 2024-09-17 04:39:08 +00:00
launcher
nn
optim [BE]: Update mypy to 1.11.2 (#133816) 2024-09-16 19:44:11 +00:00
pipelining unflatten with specialized graphs per submodule call (#137013) 2024-10-03 00:55:44 +00:00
rpc
tensor Allow parallelize_module to get device_mesh from ambient context (#134247) 2024-10-09 00:19:03 +00:00
__init__.py
_checkpointable.py
_composable_state.py
_functional_collectives.py [BE]: Update mypy to 1.11.2 (#133816) 2024-09-16 19:44:11 +00:00
_functional_collectives_impl.py
_state_dict_utils.py [DSD] Fix loading uneven full tensor into sharded state dict (#136365) 2024-09-23 16:35:58 +00:00
argparse_util.py
c10d_logger.py
collective_utils.py
constants.py
CONTRIBUTING.md
device_mesh.py [DeviceMesh][EZ] Add group description to new group (#136558) 2024-09-28 03:09:41 +00:00
distributed_c10d.py [c10d] Fix the device query story of ProcessGroup (#136790) 2024-10-03 01:36:22 +00:00
launch.py
logging_handlers.py
remote_device.py
rendezvous.py [reland] [torchelastic][c10d] Fix store prefix race in rendezvous (#136768) 2024-09-26 17:37:07 +00:00
run.py
utils.py [BE] fix circular import in torch/distributed/utils.py (#136286) 2024-09-22 20:54:12 +00:00