pytorch/torch/distributed/_composable
Andrew Gu 57fba6fd86 [FSDP][9/N] Introduce CustomPolicy (#104986)
This PR adds a new `CustomPolicy` that behaves like the existing `lambda_auto_wrap_policy`, except that it (1) uses the new auto wrapping infrastructure and (2) allows overriding FSDP kwargs for particular module instances. (1) gives it access to that infrastructure's validation checks (e.g., for frozen parameters), and (2) makes it as expressive as manual wrapping. This should let us effectively deprecate manual wrapping if desired.

The API is as follows:
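```python
from typing import Any, Dict, Union
import torch.nn as nn
from torch.distributed.fsdp.wrap import CustomPolicy
def lambda_fn(module: nn.Module) -> Union[bool, Dict[str, Any]]:
    ...  # return False/{}, True, or a dict of FSDP kwarg overrides
policy = CustomPolicy(lambda_fn)
```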
The `lambda_fn` can return:
- `False` or `{}` to indicate no wrapping
- `True` to indicate wrapping while inheriting the root's FSDP kwargs
- Non-empty `dict` to indicate wrapping while overriding the specified FSDP kwargs and inheriting the rest from the root
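The return-value contract above can be sketched in plain Python. The helper `resolve_wrap_kwargs` below is hypothetical and is not part of FSDP. It shows how each value a `lambda_fn` might return maps to a decision about whether to wrap the module, and to the kwargs the wrapped module would receive. The string values are placeholders for real `ShardingStrategy` members. This is an assumption based on the semantics described in this PR, not the actual implementation in `torch.distributed.fsdp`.

```python
from typing import Any, Dict, Optional, Union


def resolve_wrap_kwargs(
    result: Union[bool, Dict[str, Any]],
    root_kwargs: Dict[str, Any],
) -> Optional[Dict[str, Any]]:
    """Hypothetical helper (not an FSDP API).

    Returns the kwargs to wrap a module with, or None if it should not be wrapped.
    """
    if isinstance(result, bool):
        # True: wrap and inherit every root kwarg. False: do not wrap.
        return dict(root_kwargs) if result else None
    if isinstance(result, dict):
        if not result:
            # An empty dict means "do not wrap", the same as False.
            return None
        # Non-empty dict: start from the root kwargs, then override the given keys.
        merged = dict(root_kwargs)
        merged.update(result)
        return merged
    raise TypeError(f"lambda_fn must return bool or dict, got {type(result).__name__}")


# Placeholder root kwargs; real code would pass e.g. ShardingStrategy.FULL_SHARD.
root_kwargs = {"sharding_strategy": "FULL_SHARD", "use_orig_params": True}
print(resolve_wrap_kwargs(False, root_kwargs))  # None
print(resolve_wrap_kwargs({}, root_kwargs))  # None
print(resolve_wrap_kwargs(True, root_kwargs))
print(resolve_wrap_kwargs({"sharding_strategy": "SHARD_GRAD_OP"}, root_kwargs))
```

The helper copies the root kwargs before merging. A per-module override therefore never mutates the shared root configuration.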

---

After this PR, the follow-up work items for auto wrapping are:
1. Add shared parameter validation
2. (Longer-term / exploratory) Add a policy that provides a reasonable auto wrapping with "minimal" user input

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104986
Approved by: https://github.com/ezyang
ghstack dependencies: #104427, #104967, #104999, #104969
2023-08-03 12:46:36 +00:00
| File | Last commit | Date |
|---|---|---|
| `__init__.py` | | |
| `checkpoint_activation.py` | [Composable] Use non-reentrant generator, remove reentrant (#105176) | 2023-07-26 07:03:03 +00:00 |
| `contract.py` | [FSDP x dynamo] simplify registry keys (#104209) | 2023-07-25 07:16:22 +00:00 |
| `fully_shard.py` | [FSDP][9/N] Introduce CustomPolicy (#104986) | 2023-08-03 12:46:36 +00:00 |
| `replicate.py` | [Replicate] Add unit test with replicate param names (#102401) | 2023-05-31 18:41:03 +00:00 |