Mirror of https://github.com/saymrwulf/pytorch.git, synced 2026-05-15 21:00:47 +00:00
This is a follow-up to the last PR that greatly simplifies the approach; the result should be much cleaner.

**Details**

Let `N` denote the number of original parameters flattened into a given flat parameter, and let `M` denote the number of extra padding tensors.

- `_numels_with_padding`: length `N + M`
- `_is_padding_mask`: length `N + M`
- `_numels`, `_param_infos`, `_shapes`, `_fqns`, `_param_extensions`: length `N`

`_shard_param_indices` and `_shard_param_offsets` were used to determine (1) whether a given original parameter is in the local shard and, if so, (2) its offset in the _sharded_ flat parameter and (3) how many of its elements (numel) are in the _sharded_ flat parameter. This PR reworks how (1), (2), and (3) are computed so that the previously mentioned data structures can be simplified. In particular, it saves one extra tuple `_shard_param_infos: Tuple[_ShardParamInfo, ...]` of length `N`, where each `_ShardParamInfo` entry gives exactly the needed info. For example, the offset into the sharded flat parameter is now precomputed, so we no longer need to accumulate `offset = 0; offset += numel_in_shard` in a `for` loop each time.

For the optimizer state dict, `FSDPParamInfo.param_indices` now maps to indices with respect to the length-`N` data structures, not the length-`N + M` ones. The only purpose of `param_indices` is to index into `flat_param._shard_param_infos[i]` to get the info needed to flatten the unsharded original parameter's optimizer state and extract the part in the local shard.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97796
Approved by: https://github.com/rohan-varma
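The precomputed per-parameter shard info can be illustrated with a small, pure-Python sketch. It assumes the flat parameter is the concatenation of the `_numels_with_padding` entries, that the local shard is an inclusive `[shard_start, shard_end]` range in unsharded flat-parameter coordinates, and it ignores any trailing padding FSDP may add so the flat parameter divides evenly across ranks. `ShardParamInfo`, `compute_shard_param_infos`, and `local_shard_of_flat_state` are hypothetical names for illustration only, not FSDP's actual `_ShardParamInfo` definition or helpers.

```python
from typing import List, NamedTuple, Optional, Sequence, Tuple


class ShardParamInfo(NamedTuple):
    """Hypothetical shard info for one original (non-padding) parameter."""

    in_shard: bool
    offset_in_shard: Optional[int]  # start index into the local *sharded* flat parameter
    numel_in_shard: Optional[int]  # how many of this parameter's elements are in the shard
    intra_param_start_idx: Optional[int]  # start index into the flattened original parameter
    intra_param_end_idx: Optional[int]  # inclusive end index into the flattened original parameter


def compute_shard_param_infos(
    numels_with_padding: Sequence[int],  # length N + M
    is_padding_mask: Sequence[bool],  # length N + M
    shard_start: int,  # inclusive, in unsharded flat-parameter coordinates
    shard_end: int,  # inclusive, in unsharded flat-parameter coordinates
) -> Tuple[ShardParamInfo, ...]:
    """Precompute one entry per original parameter (length N), skipping padding."""
    if len(numels_with_padding) != len(is_padding_mask):
        raise ValueError("numels_with_padding and is_padding_mask must have equal length")
    infos: List[ShardParamInfo] = []
    param_start = 0  # running start offset of the current entry in the unsharded flat parameter
    for numel, is_padding in zip(numels_with_padding, is_padding_mask):
        param_end = param_start + numel - 1  # inclusive end of this entry
        if not is_padding:
            # Intersect [param_start, param_end] with [shard_start, shard_end].
            lo = max(param_start, shard_start)
            hi = min(param_end, shard_end)
            if lo > hi:  # no overlap: the parameter is not in the local shard
                infos.append(ShardParamInfo(False, None, None, None, None))
            else:
                infos.append(
                    ShardParamInfo(
                        in_shard=True,
                        offset_in_shard=lo - shard_start,
                        numel_in_shard=hi - lo + 1,
                        intra_param_start_idx=lo - param_start,
                        intra_param_end_idx=hi - param_start,
                    )
                )
        param_start += numel  # padding entries still occupy space in the flat parameter
    return tuple(infos)


def local_shard_of_flat_state(flat_state: Sequence[float], info: ShardParamInfo) -> List[float]:
    """Slice the part of a flattened, unsharded per-parameter state that lies in the local shard."""
    if not info.in_shard:
        return []
    assert info.intra_param_start_idx is not None and info.intra_param_end_idx is not None
    return list(flat_state[info.intra_param_start_idx : info.intra_param_end_idx + 1])


if __name__ == "__main__":
    # Two original parameters (6 and 3 elements) with one padding element between them,
    # split evenly across two ranks: rank 0 owns [0, 4], rank 1 owns [5, 9].
    numels, mask = [6, 1, 3], [False, True, False]
    for rank, (start, end) in enumerate([(0, 4), (5, 9)]):
        print(rank, compute_shard_param_infos(numels, mask, start, end))
```

In this example, rank 1 gets offset 0 and numel 1 for the first parameter (only its last element spills into rank 1's shard) and offset 2 and numel 3 for the second, because index 1 of rank 1's local shard is padding. Since padding entries produce no `ShardParamInfo`, the returned tuple has length `N`, so an index into the length-`N` structures can be used directly to look up the slice of a flattened optimizer state that belongs to the local shard.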
| Name |
|---|
| _composable |
| _shard |
| _spmd |
| _tensor |
| _tools |
| algorithms |
| bin |
| checkpoint |
| elastic |
| fsdp |
| launcher |
| nn/jit |
| optim |
| pipeline/sync |
| rpc |
| tensor/parallel |
| argparse_util_test.py |
| test_c10d_common.py |
| test_c10d_error_logger.py |
| test_c10d_gloo.py |
| test_c10d_nccl.py |
| test_c10d_object_collectives.py |
| test_c10d_pypg.py |
| test_c10d_spawn.py |
| test_c10d_spawn_gloo.py |
| test_c10d_spawn_nccl.py |
| test_c10d_spawn_ucc.py |
| test_data_parallel.py |
| test_distributed_spawn.py |
| test_dynamo_distributed.py |
| test_functional_api.py |
| test_inductor_collectives.py |
| test_launcher.py |
| test_multi_threaded_pg.py |
| test_nccl.py |
| test_pg_wrapper.py |
| test_store.py |