Commit aa61e251d4 by Andrew Gu: [FSDP2] Added shard_placement_fn arg (#137496)
## Overview
This PR adds a `shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]]]` arg to `fully_shard` that allows users to specify FSDP sharding on a nonzero tensor dim. If a parameter is sharded on a nonzero dim, then that dim's size must be divisible by the FSDP shard world size.

```
from typing import Optional

import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed._tensor import Shard

# Example: shard each parameter on its largest tensor dim
def shard_placement_fn(param: nn.Parameter) -> Optional[Shard]:
    largest_dim = largest_dim_size = -1
    for dim, dim_size in enumerate(param.shape):
        if dim_size > largest_dim_size:
            largest_dim = dim
            largest_dim_size = dim_size
    return Shard(largest_dim)

fully_shard(module, shard_placement_fn=shard_placement_fn)
```
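
As a concrete check of the divisibility requirement, the snippet below applies the policy above to an `nn.Linear(512, 256)`. This is a hypothetical sanity check, not part of the PR; it reuses `shard_placement_fn` from the example and assumes a shard world size of 4:

```
import torch.nn as nn

world_size = 4  # hypothetical FSDP shard world size
linear = nn.Linear(512, 256)
for name, param in linear.named_parameters():
    placement = shard_placement_fn(param)  # defined in the example above
    dim_size = param.shape[placement.dim]
    # The chosen dim's size must be divisible by the shard world size
    assert dim_size % world_size == 0, (name, placement, dim_size)
    print(f"{name}: shape={tuple(param.shape)} -> {placement}")
```

Here the weight of shape `(256, 512)` is sharded on dim 1 (512 is divisible by 4), and the bias of shape `(256,)` is sharded on dim 0.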

## Follow-Ups
- **Copy kernels:** For all-gather copy-out, we currently copy out to temporaries and then do a chunk along dim 0 followed by a cat along the shard dim, incurring an extra copy for parameters sharded on a nonzero tensor dim. Similarly, for reduce-scatter copy-in, we currently do a chunk along the shard dim followed by a cat along dim 0, incurring an extra copy for gradients sharded on a nonzero tensor dim (see the sketch below). @yifuwang has ideas for adding additional split-size args to the copy ops that would allow fusing these extra copies into the existing all-gather copy-out and reduce-scatter copy-in.
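
A minimal sketch of the extra copy on the all-gather path (shapes and the single-process simulation are illustrative only, not the actual kernel code):

```
import torch

world_size = 4
# Local shard of a (256, 512) parameter sharded on dim 1, i.e. Shard(1)
shard = torch.randn(256, 512 // world_size)
# All-gather lays the world's shards out contiguously, i.e. concatenated
# along dim 0 (simulated here in a single process):
gathered = torch.cat([shard for _ in range(world_size)], dim=0)
# Recovering the unsharded parameter requires chunk(dim=0) -> cat(shard dim),
# which is the extra copy described above:
unsharded = torch.cat(torch.chunk(gathered, world_size, dim=0), dim=1)
assert unsharded.shape == (256, 512)
```

Fusing split-size args into the copy kernels would let the all-gather copy-out write directly into the shard-dim layout, eliminating the intermediate `gathered`-to-`unsharded` copy (and symmetrically for reduce-scatter copy-in).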

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137496
Approved by: https://github.com/weifengpy
ghstack dependencies: #137593
2024-10-09 19:13:32 +00:00