pytorch/torch/distributed
Wanchao Liang 5c48ff20b5 AsyncCollectiveTensor: dont sync on view ops (#105240)
AsyncCollectiveTensor is a tensor subclass that is meant to "delay synchronization" when you call into the functional collectives APIs. It does this (if I understand correctly) by internally holding an "unsynchronized" version of the tensor, which is the result of the communication op, and internally calling `.wait()` to synchronize the data the next time it is used.

Previously, these `wait()` calls would happen immediately, because `AsyncCollectiveTensor` gets wrapped by `DTensor()`, which calls `.detach()` on its inner tensor, immediately causing the sync (code: 1518d5eec4/torch/distributed/_tensor/api.py (L207)).

`AsyncCollectiveTensor` shouldn't need to synchronize when you `detach()` it, though. In fact, it should be fine to skip synchronization for any view op, since view ops only touch tensor metadata and not the actual data. This PR tries to update `AsyncCollectiveTensor` to delay `wait()` calls whenever the subclass encounters a view op.
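
To make the idea concrete, here is a minimal pure-Python sketch of "don't sync on view ops". It is not the code from this PR and does not use torch: `FakeWork` and `LazyCollectiveResult` are hypothetical names invented for illustration, and a flat list stands in for tensor data. View-like operations (`detach`, `view`) only rewrap metadata, while an operation that reads data (`sum`) is what triggers `wait()`.

```python
import math


class FakeWork:
    """Hypothetical stand-in for an in-flight collective; counts wait() calls."""

    def __init__(self, data):
        self._data = list(data)
        self.wait_calls = 0

    def wait(self):
        # A real collective would block here until communication finishes.
        self.wait_calls += 1
        return self._data


class LazyCollectiveResult:
    """Toy analogue of a lazily synchronized tensor (illustrative only)."""

    def __init__(self, work, shape):
        self._work = work
        self.shape = tuple(shape)

    def detach(self):
        # View-like op: shares the pending work, no synchronization needed.
        return LazyCollectiveResult(self._work, self.shape)

    def view(self, *shape):
        # View-like op: only metadata is checked and changed.
        if math.prod(shape) != math.prod(self.shape):
            raise ValueError("view must preserve the number of elements")
        return LazyCollectiveResult(self._work, shape)

    def sum(self):
        # Data-dependent op: this is where the sync has to happen.
        return sum(self._work.wait())


# Usage: chain view ops on a pending result, then read the data once.
work = FakeWork([1, 2, 3, 4, 5, 6])
result = LazyCollectiveResult(work, (6,))
viewed = result.detach().view(2, 3)  # still lazy, wait() not yet called
total = viewed.sum()                 # first real use, wait() is called here
```

The wrapper returned by each view op keeps sharing the same pending work object, so the synchronization cost is paid once, at the first operation that actually needs the data.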

Added some light testing that runs some DTensor compute followed by view ops and confirms that the output is still an `AsyncCollectiveTensor` when we call `.to_local()`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105240
Approved by: https://github.com/wanchaol, https://github.com/fduwjj, https://github.com/wconstab
2023-08-11 19:20:25 +00:00
_composable [FSDP][9/N] Introduce CustomPolicy (#104986) 2023-08-03 12:46:36 +00:00
_shard
_sharded_tensor
_sharding_spec
_spmd
_tensor Make RNGStateTracker support cuda-like device (#106771) 2023-08-10 19:14:33 +00:00
_tools
algorithms [FSDP][8/N] Replace _FSDPPolicy.policy with _Policy._run_policy (#104969) 2023-08-03 12:42:14 +00:00
autograd
benchmarks
checkpoint [DCP] Modify tensor saving logic in DCP (#106415) 2023-08-09 00:16:10 +00:00
elastic Format: fixing multiple string concatenation in single line (#106013) 2023-07-26 18:39:18 +00:00
examples
fsdp [FSDP] Fix train -> EMA -> eval with mixed precision (#106858) 2023-08-10 19:32:43 +00:00
launcher
nn
optim [Optim in backward] API to retrieve in-backward optimizers (#105991) 2023-07-29 01:36:25 +00:00
pipeline Format: fixing multiple string concatenation in single line (#106013) 2023-07-26 18:39:18 +00:00
rpc
tensor Make RNGStateTracker support cuda-like device (#106771) 2023-08-10 19:14:33 +00:00
__init__.py
_composable_state.py
_functional_collectives.py AsyncCollectiveTensor: dont sync on view ops (#105240) 2023-08-11 19:20:25 +00:00
_functional_collectives_impl.py AsyncCollectiveTensor: dont sync on view ops (#105240) 2023-08-11 19:20:25 +00:00
argparse_util.py
c10d_logger.py
collective_utils.py
constants.py
CONTRIBUTING.md
distributed_c10d.py
launch.py
logging_handlers.py
remote_device.py
rendezvous.py
run.py fix torchrun script for custom device (#105443) 2023-07-31 05:46:23 +00:00
utils.py