Mirror of https://github.com/saymrwulf/pytorch.git, synced 2026-05-14 20:57:59 +00:00
## Overview
This PR adds a `shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]]]` argument to `fully_shard` that lets users specify FSDP sharding on a nonzero tensor dim. When sharding on a nonzero dim, that dim's size must be divisible by the FSDP shard world size.
```python
# Example: shard each parameter along its largest tensor dim.
def shard_placement_fn(param: nn.Parameter) -> Optional[Shard]:
    largest_dim = largest_dim_size = -1
    for dim, dim_size in enumerate(param.shape):
        if dim_size > largest_dim_size:
            largest_dim = dim
            largest_dim_size = dim_size
    return Shard(largest_dim)

fully_shard(module, shard_placement_fn=shard_placement_fn)
```
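As a local illustration of the divisibility requirement (no distributed setup; `world_size` and the shapes here are stand-ins, not part of the PR), picking the largest dim and checking that its size divides evenly might look like:

```python
import torch
import torch.nn as nn

world_size = 4  # stand-in for the FSDP shard world size

def largest_dim(param: nn.Parameter) -> int:
    # Same selection rule as shard_placement_fn above: first dim with the
    # largest size (max with a key keeps the first maximum, matching the
    # strict ">" comparison in the example).
    return max(range(param.ndim), key=lambda d: param.shape[d])

weight = nn.Parameter(torch.empty(16, 64))  # largest dim is 1
dim = largest_dim(weight)
# The chosen dim's size must be divisible by the shard world size.
assert weight.shape[dim] % world_size == 0
```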
## Follow-Ups
- **Copy kernels:** For all-gather copy-out, we currently copy out to temporaries and then chunk along dim 0 -> cat along the shard dim, incurring an extra copy for parameters sharded on a nonzero tensor dim. Similarly, for reduce-scatter copy-in, we currently chunk along the shard dim -> cat along dim 0, incurring an extra copy for gradients sharded on a nonzero tensor dim. @yifuwang has ideas for adding additional split-size args to the copy ops that would allow fusing these extra copies into the existing all-gather copy-out and reduce-scatter copy-in.
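The extra copies described above can be sketched with plain tensor ops (a toy single-process illustration, not the actual FSDP kernels; `world_size`, `shard_dim`, and the shapes are made up for the example):

```python
import torch

world_size = 4
shard_dim = 1
# Toy "parameter" of shape (8, 12), sharded on dim 1 across 4 ranks.
full = torch.arange(8 * 12, dtype=torch.float32).reshape(8, 12)
shards = list(full.chunk(world_size, dim=shard_dim))  # each rank holds (8, 3)

# All-gather lays the per-rank shards out contiguously, i.e. concatenated
# along dim 0 -- not the original parameter layout when shard_dim != 0.
gathered = torch.cat(shards, dim=0)  # shape (32, 3)

# All-gather copy-out today: chunk along dim 0, then cat along the shard
# dim -- this cat is the extra copy for nonzero-dim sharding.
param = torch.cat(gathered.chunk(world_size, dim=0), dim=shard_dim)
assert torch.equal(param, full)

# Reduce-scatter copy-in is the inverse: chunk along the shard dim, then
# cat along dim 0 to produce the collective's expected input layout.
rs_input = torch.cat(full.chunk(world_size, dim=shard_dim), dim=0)
assert torch.equal(rs_input, gathered)
```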
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137496
Approved by: https://github.com/weifengpy
ghstack dependencies: #137593