pytorch/test/distributed
Andrew Gu 662a8cf74d [FSDP][8/N] Simplify addr padding internals (#97796)
This is a follow-up to the previous PR that greatly simplifies its approach. The result should be much cleaner.

**Details**
Let `N` denote the number of original parameters flattened into a given flat parameter, and let `M` denote the number of extra padding tensors in that flat parameter.
- `_numels_with_padding`: length `N + M`
- `_is_padding_mask`: length `N + M`
- `_numels`, `_param_infos`, `_shapes`, `_fqns`, `_param_extensions`: length `N`
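The length-`N + M` structures can be related to the length-`N` ones through `_is_padding_mask`. The sketch below uses plain Python lists as hypothetical stand-ins for these private `FlatParameter` attributes. It does not use PyTorch, and the example sizes are made up. It drops padding entries and maps each padded index to its unpadded index, with `-1` marking padding.

```python
from typing import List


def unpadded_numels(numels_with_padding: List[int], is_padding_mask: List[bool]) -> List[int]:
    """Recover the length-N numels from the length-(N + M) numels by dropping padding."""
    if len(numels_with_padding) != len(is_padding_mask):
        raise ValueError("both length-(N + M) structures must have the same length")
    return [numel for numel, is_pad in zip(numels_with_padding, is_padding_mask) if not is_pad]


def padded_to_unpadded_index(is_padding_mask: List[bool]) -> List[int]:
    """Map each length-(N + M) index to its length-N index, using -1 for padding entries."""
    mapping = []
    next_idx = 0
    for is_pad in is_padding_mask:
        if is_pad:
            mapping.append(-1)
        else:
            mapping.append(next_idx)
            next_idx += 1
    return mapping


# Hypothetical example: N = 3 original parameters with M = 2 padding tensors interleaved.
numels_with_padding = [10, 6, 4, 2, 8]
is_padding_mask = [False, True, False, True, False]
print(unpadded_numels(numels_with_padding, is_padding_mask))  # [10, 4, 8]
print(padded_to_unpadded_index(is_padding_mask))  # [0, -1, 1, -1, 2]
```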

`_shard_param_indices` and `_shard_param_offsets` were used to determine (1) whether a given original parameter is in the local shard and, if so, (2) its offset in the _sharded_ flat parameter and (3) how many of its elements (numel) are in the _sharded_ flat parameter.

This PR reworks how (1), (2), and (3) are computed so that the previously mentioned data structures can be simplified. In particular, it saves one extra tuple, `_shard_param_infos: Tuple[_ShardParamInfo, ...]`, of length `N`, where each `_ShardParamInfo` entry gives exactly the needed info. For example, the offset into the sharded flat parameter is now precomputed, so we no longer need to accumulate `offset = 0; offset += numel_in_shard` in a `for` loop each time.
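A minimal sketch of precomputing one info entry per original parameter follows, assuming the local shard covers the element range `[shard_start, shard_end)` of the unsharded, padded flat parameter. `ShardParamInfo` and `compute_shard_param_infos` are hypothetical names for illustration, and their fields may not match PyTorch's actual `_ShardParamInfo`. Because positions are measured in the padded flat parameter, any padding that falls inside the shard is reflected in `offset_in_shard`.

```python
from typing import List, NamedTuple, Optional, Tuple


class ShardParamInfo(NamedTuple):
    # Hypothetical stand-in for FSDP's per-original-parameter shard metadata.
    in_shard: bool
    offset_in_shard: Optional[int]    # start offset within the sharded flat parameter
    numel_in_shard: Optional[int]     # number of this parameter's elements in the shard
    intra_param_start: Optional[int]  # first element index within the original parameter
    intra_param_end: Optional[int]    # one past the last element index (exclusive)


def compute_shard_param_infos(
    numels_with_padding: List[int],
    is_padding_mask: List[bool],
    shard_start: int,
    shard_end: int,  # exclusive
) -> Tuple[ShardParamInfo, ...]:
    """Return one entry per original (non-padding) parameter, i.e. a length-N tuple."""
    if len(numels_with_padding) != len(is_padding_mask):
        raise ValueError("both length-(N + M) structures must have the same length")
    infos = []
    param_start = 0  # start of the current entry in the unsharded, padded flat parameter
    for numel, is_pad in zip(numels_with_padding, is_padding_mask):
        param_end = param_start + numel
        if not is_pad:
            # Intersect this parameter's element range with the local shard's range.
            overlap_start = max(param_start, shard_start)
            overlap_end = min(param_end, shard_end)
            if overlap_start < overlap_end:
                infos.append(ShardParamInfo(
                    in_shard=True,
                    offset_in_shard=overlap_start - shard_start,
                    numel_in_shard=overlap_end - overlap_start,
                    intra_param_start=overlap_start - param_start,
                    intra_param_end=overlap_end - param_start,
                ))
            else:
                infos.append(ShardParamInfo(False, None, None, None, None))
        param_start = param_end
    return tuple(infos)


# Hypothetical example: 3 parameters plus 2 padding tensors (24 elements), sharded 3 ways.
infos = compute_shard_param_infos([10, 2, 7, 1, 4], [False, True, False, True, False], 8, 16)
print(infos[1].offset_in_shard, infos[1].numel_in_shard)  # 4 4
print(infos[2].in_shard)  # False
```

Consumers then read `infos[i]` directly instead of re-walking the parameters and summing `numel_in_shard` to find each offset.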

For the optimizer state dict, `FSDPParamInfo.param_indices` now maps to indices with respect to the length-`N` data structures, not the length-`N + M` ones. The only purpose of `param_indices` is to index into `flat_param._shard_param_infos[i]` to get the info needed to flatten the unsharded original parameter's optimizer state and extract the part in the local shard.
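As a rough illustration of that lookup, the sketch below assumes `param_indices` maps a parameter's FQN to its length-`N` index and that each optimizer state value (for example, an Adam `exp_avg`) has already been flattened into a plain list. `ShardParamInfo`, `extract_local_optim_state`, and the FQNs are hypothetical and trimmed to the fields this step needs; the real code operates on tensors.

```python
from typing import Dict, List, NamedTuple, Optional, Tuple


class ShardParamInfo(NamedTuple):
    # Hypothetical, trimmed stand-in for FSDP's per-original-parameter shard metadata.
    in_shard: bool
    intra_param_start: Optional[int]  # first element of this parameter in the local shard
    intra_param_end: Optional[int]    # one past the last element (exclusive)


def extract_local_optim_state(
    fqn: str,
    unsharded_flat_state: List[float],
    param_indices: Dict[str, int],
    shard_param_infos: Tuple[ShardParamInfo, ...],
) -> Optional[List[float]]:
    """Slice the locally sharded part out of one parameter's flattened optimizer state."""
    # The index is into the length-N structures, so no padding entries need skipping.
    info = shard_param_infos[param_indices[fqn]]
    if not info.in_shard:
        return None  # this rank holds none of the parameter's elements
    return unsharded_flat_state[info.intra_param_start:info.intra_param_end]


# Hypothetical example with N = 3 original parameters.
param_indices = {"layer.weight": 0, "layer.bias": 1, "out.weight": 2}
shard_param_infos = (
    ShardParamInfo(True, 8, 10),
    ShardParamInfo(True, 0, 4),
    ShardParamInfo(False, None, None),
)
print(extract_local_optim_state(
    "layer.weight", [float(i) for i in range(10)], param_indices, shard_param_infos))  # [8.0, 9.0]
print(extract_local_optim_state(
    "out.weight", [0.0] * 4, param_indices, shard_param_infos))  # None
```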

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97796
Approved by: https://github.com/rohan-varma
2023-03-28 22:19:44 +00:00
_composable Bump black version to 23.1.0 (#96578) 2023-03-15 06:27:59 +00:00
_shard [10/N] Remove ST init ops (#96985) 2023-03-22 20:26:18 +00:00
_spmd Allow DTensor to trigger collecives before inplace ops (#97787) 2023-03-28 21:06:51 +00:00
_tensor [ci] disable some dtensor tests (#97358) 2023-03-23 22:08:31 +00:00
_tools
algorithms Fix "sandcastle_skip_if decorator name is confusing" (#95649) 2023-03-03 09:29:40 +00:00
bin
checkpoint [DCP] Expose create_read_items_for_chunk_list helper. (#97570) 2023-03-28 02:25:04 +00:00
elastic Fix "sandcastle_skip_if decorator name is confusing" (#95649) 2023-03-03 09:29:40 +00:00
fsdp [FSDP][8/N] Simplify addr padding internals (#97796) 2023-03-28 22:19:44 +00:00
launcher Fix "sandcastle_skip_if decorator name is confusing" (#95649) 2023-03-03 09:29:40 +00:00
nn/jit
optim [Optim in backward] register_hook=False API (#95096) 2023-03-15 14:33:13 +00:00
pipeline/sync Run tests in USE_PYTEST_LIST through run_tests (#95659) 2023-02-28 22:09:01 +00:00
rpc [BE] [3/3] Rewrite super() calls in test (#94592) 2023-02-12 22:20:53 +00:00
tensor/parallel Fix typos under torch/distributed directory (#95638) 2023-03-27 21:13:44 +00:00
argparse_util_test.py
test_c10d_common.py [BE] Remove unnecessary dict comprehensions (#97116) 2023-03-20 00:56:57 +00:00
test_c10d_error_logger.py [BE] [3/3] Rewrite super() calls in test (#94592) 2023-02-12 22:20:53 +00:00
test_c10d_gloo.py Fix "sandcastle_skip_if decorator name is confusing" (#95649) 2023-03-03 09:29:40 +00:00
test_c10d_nccl.py Rewrite NCCL watchdog to more reliably throw timeout (#97066) 2023-03-25 04:30:20 +00:00
test_c10d_object_collectives.py [BE] [3/3] Rewrite super() calls in test (#94592) 2023-02-12 22:20:53 +00:00
test_c10d_pypg.py [BE] [3/3] Rewrite super() calls in test (#94592) 2023-02-12 22:20:53 +00:00
test_c10d_spawn.py [BE] [3/3] Rewrite super() calls in test (#94592) 2023-02-12 22:20:53 +00:00
test_c10d_spawn_gloo.py Fix "sandcastle_skip_if decorator name is confusing" (#95649) 2023-03-03 09:29:40 +00:00
test_c10d_spawn_nccl.py Fix "sandcastle_skip_if decorator name is confusing" (#95649) 2023-03-03 09:29:40 +00:00
test_c10d_spawn_ucc.py Fix "sandcastle_skip_if decorator name is confusing" (#95649) 2023-03-03 09:29:40 +00:00
test_data_parallel.py Fix "sandcastle_skip_if decorator name is confusing" (#95649) 2023-03-03 09:29:40 +00:00
test_distributed_spawn.py
test_dynamo_distributed.py Make dynamo-FSDP skip guards (#97463) 2023-03-28 04:04:34 +00:00
test_functional_api.py [PTD] Introduce tracing friendly collectives. (#93990) 2023-02-16 15:35:01 +00:00
test_inductor_collectives.py Provide more informative kernel names in Inductor (#95940) 2023-03-07 18:02:10 +00:00
test_launcher.py [BE] Prefer dash over underscore in command-line options (#94505) 2023-02-09 20:16:49 +00:00
test_multi_threaded_pg.py Add gather to MTPG (#97555) 2023-03-27 19:37:02 +00:00
test_nccl.py Reduce pytest blocklist (#96016) 2023-03-07 18:30:27 +00:00
test_pg_wrapper.py [BE] [3/3] Rewrite super() calls in test (#94592) 2023-02-12 22:20:53 +00:00
test_store.py [BE] Remove dependency on six and future (#94709) 2023-02-14 09:14:14 +00:00