pytorch/torch
Rohan Varma 4e4626a23d Join-based API to support DDP uneven inputs (#42577)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/42577

Closes https://github.com/pytorch/pytorch/issues/38174. Implements a join-based API to support training with the DDP module in the scenario where different processes have different numbers of inputs. The implementation follows the description in https://github.com/pytorch/pytorch/issues/38174. Details are available in the RFC, but as a summary, we make the following changes:

#### Approach
1) Add a context manager `torch.nn.parallel.distributed.join`
2) In the forward pass, we schedule a "present" allreduce to which non-joined processes contribute 1 and joined processes contribute 0. This lets us keep track of joined processes and know when all processes have joined.
3) When a process depletes its input and exits the context manager, it enters "joining" mode and attempts to "shadow" the collective communication calls made in the model's forward and backward pass. For example, we schedule the same allreduces in the same order as the backward pass, but with zero tensors.
4) We adjust the allreduce division logic to divide by the effective world size (the number of non-joined processes) rather than the absolute world size to maintain correctness.
5) At the end of training, the last joined process is selected to be the "authoritative" model copy.
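The presence and division bookkeeping in steps 2–4 can be sketched in pure Python (a hypothetical simulation, not the actual PyTorch implementation or API): ranks contribute 1 to a "present" allreduce while they still have inputs and 0 once joined; joined ranks shadow the gradient allreduce with zeros, and the result is divided by the effective world size rather than the absolute world size.

```python
def presence_allreduce(active):
    # active: one bool per rank; the sum is the number of non-joined ranks.
    return sum(1 if a else 0 for a in active)

def grad_allreduce(local_grads, active):
    # Joined ranks shadow the allreduce with a zero contribution, and the
    # sum is divided by the effective world size (non-joined ranks only).
    effective_world_size = presence_allreduce(active)
    total = sum(g if a else 0.0 for g, a in zip(local_grads, active))
    return total / effective_world_size

# 4 ranks; rank 3 has depleted its inputs and joined.
active = [True, True, True, False]
grads = [2.0, 4.0, 6.0, 999.0]  # rank 3's local value is ignored (zeroed)
assert presence_allreduce(active) == 3
assert grad_allreduce(grads, active) == 4.0  # (2 + 4 + 6) / 3
```

Dividing by the absolute world size of 4 here would instead yield 3.0, silently shrinking the averaged gradient — which is why step 4 is needed for correctness.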

We also make some miscellaneous changes, such as adding a `rank` argument to `_distributed_broadcast_coalesced` and exposing some getters/setters on `Reducer`, to support the above changes.

#### How is it tested?
We have tests covering the following models/scenarios:
- [x] Simple linear model
- [x] Large convolutional model
- [x] Large model with module buffers that are broadcast in the forward pass (resnet). We verify this with a helper function `will_sync_module_buffers` and ensure this is true for ResNet (due to batchnorm)
- [x] Scenario where a rank calls join() without iterating at all, and thus without rebuilding buckets (which requires collective communication)
- [x] Model with unused parameters (with `find_unused_parameters=True`)
- [x] Scenarios where different processes iterate for different numbers of iterations
- [x] Test consistency in tie-breaking when multiple ranks are the last ones to join
- [x] Test that we divide by the effective world size (the number of non-joined processes)
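The basic usage pattern exercised by these tests can be sketched as follows — a minimal single-process sketch, assuming the context manager is invoked as a method on the wrapped DDP module (as in the released API) and using a single-rank gloo group with file-based init purely to keep the example self-contained:

```python
import os
import tempfile

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Single-process gloo group so the sketch runs on one CPU machine.
init_file = os.path.join(tempfile.mkdtemp(), "ddp_init")
dist.init_process_group("gloo", init_method=f"file://{init_file}",
                        rank=0, world_size=1)

model = DDP(torch.nn.Linear(4, 2))
opt = torch.optim.SGD(model.parameters(), lr=0.1)
w0 = model.module.weight.detach().clone()

# Each rank may have a different number of batches; join() lets ranks that
# deplete their data shadow the collectives of the ranks still training.
with model.join():
    for x in [torch.randn(8, 4) for _ in range(3)]:
        opt.zero_grad()
        model(x).sum().backward()
        opt.step()

dist.destroy_process_group()
```

With more than one rank, each rank would run this same loop over its own (possibly shorter) input list, and ranks that exit the loop early shadow the remaining collectives until everyone has joined.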

#### Performance implications

###### Trunk vs PR patched, 32 GPUs, batch size = 32
P50 forward + backward + optimizer batch latency & total QPS: 0.121 264/s vs 0.121 264/s
P50 backwards only batch latency & total QPS: 0.087 369/s vs 0.087 368/s

###### join(enable=True) vs without join, 32 GPUs, batch size = 32, even inputs
P50 forward + backward + optimizer batch latency & total QPS: 0.120 265/s vs 0.121 264/s
P50 backwards only batch latency & total QPS: 0.088 364/s vs 0.087 368/s

###### join(enable=False) vs without join, 32 GPUs, batch size = 32, even inputs
P50 forward + backward + optimizer batch latency & total QPS: 0.121 264/s vs 0.121 264/s
P50 backwards only batch latency & total QPS: 0.087 368/s vs 0.087 368/s

###### join(enable=True) with uneven inputs (offset = 2000), 32 GPUs, batch size = 32
P50 forward + backward + optimizer batch latency & total QPS: 0.183 174/s vs 0.121 264/s
P50 backwards only batch latency & total QPS: 0.150 213/s vs 0.087 368/s

###### join(enable=True) with uneven inputs (offset = 2000), 8 GPUs, batch size = 32
P50 forward + backward + optimizer batch latency & total QPS: 0.104 308/s vs 0.104 308/s
P50 backwards only batch latency & total QPS: 0.070 454/s vs 0.070 459/s

The two uneven-input benchmarks above were conducted with 4 GPUs immediately depleting their inputs and entering "join" mode (i.e., not iterating at all), while the remaining GPUs iterated as normal (the other 28 in the 32-GPU case). It looks like there is a pretty significant perf hit for this case when there are uneven inputs and multi-node training. Strangely, with a single node (8 GPUs), this does not reproduce.

#### Limitations
1) This is only implemented for MPSD, not SPMD. Per a discussion with mrshenli, we want to encourage the use of MPSD over SPMD for DDP.
2) This does not currently work with SyncBN or custom collective calls made in the model's forward pass. This is because the `join` class only shadows the `broadcast` for buffers in the forward pass, the gradient allreduces in the backward pass, the unused-parameters reduction, and (optionally) the rebuilt-buckets broadcasting in the backward pass. Supporting this will require additional design thought.
3) Has not been tested with the [DDP comm. hook](https://github.com/pytorch/pytorch/issues/39272), as this feature is still in progress. We will add support for this in follow-up PRs.
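The second limitation can be illustrated with a small hypothetical simulation (plain Python, not PyTorch API): a joined rank can only replay the collectives it knows about, so any extra collective issued inside the forward pass (such as SyncBN's statistics allreduce) has no counterpart on the joined rank, the schedules diverge, and training would hang.

```python
# Collectives the join logic knows how to shadow with zero contributions.
KNOWN_COLLECTIVES = ["broadcast_buffers", "allreduce_grads"]

def shadow_schedule():
    # The ordered collective calls a joined rank replays per iteration.
    return list(KNOWN_COLLECTIVES)

def active_schedule(extra_forward_collectives=()):
    # The ordered collective calls a non-joined rank actually issues.
    return ["broadcast_buffers", *extra_forward_collectives, "allreduce_grads"]

def schedules_match(active, shadow):
    # Collectives must line up one-to-one, in order, across all ranks.
    return active == shadow

# Plain DDP forward/backward: the schedules line up.
assert schedules_match(active_schedule(), shadow_schedule())
# A SyncBN-style collective in the forward pass breaks the pairing.
assert not schedules_match(active_schedule(["syncbn_stats_allreduce"]),
                           shadow_schedule())
```

In the mismatched case, the non-joined ranks would block forever inside the unshadowed collective, which is why supporting such models needs additional design work.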
ghstack-source-id: 111033819

Reviewed By: mrshenli

Differential Revision: D22893859

fbshipit-source-id: dd02a7aac6c6cd968db882c62892ee1c48817fbe
2020-08-31 13:29:03 -07:00
_C Make ExtraFilesMap return bytes instead of str (#43241) 2020-08-28 19:11:33 -07:00
autograd [JIT] Add JIT support for torch.no_grad (#41371) 2020-08-27 15:32:57 -07:00
backends
contrib
csrc Join-based API to support DDP uneven inputs (#42577) 2020-08-31 13:29:03 -07:00
cuda Additional error checking for torch.cuda.nccl APIs. (#43247) 2020-08-26 13:50:00 -07:00
distributed Publish all_gather_object and gather_object docs (#43772) 2020-08-31 13:28:00 -07:00
distributions
fft
for_onnx
futures
fx [fx] GraphModule.src -> GraphModule.code (#43655) 2020-08-31 11:26:05 -07:00
jit Make ExtraFilesMap return bytes instead of str (#43241) 2020-08-28 19:11:33 -07:00
legacy
lib Move torch/csrc/utils/hash.h to c10/util/hash.h. (#42503) 2020-08-29 17:47:00 -07:00
linalg Add torch.linalg.norm (#42749) 2020-08-28 18:28:33 -07:00
multiprocessing
nn Join-based API to support DDP uneven inputs (#42577) 2020-08-31 13:29:03 -07:00
onnx [ONNX] Utilize ONNX shape inference for ONNX exporter (#40628) 2020-08-30 18:35:46 -07:00
optim
quantization Revert D23385091: [quant][graphmode][fx] Add top level APIs 2020-08-31 12:18:29 -07:00
sparse
testing Join-based API to support DDP uneven inputs (#42577) 2020-08-31 13:29:03 -07:00
utils Enable complex blas for ROCm. (#43744) 2020-08-30 22:43:54 -07:00
__config__.py
__future__.py
__init__.py
_appdirs.py
_classes.py
_jit_internal.py [JIT] Add JIT support for torch.no_grad (#41371) 2020-08-27 15:32:57 -07:00
_linalg_utils.py
_lobpcg.py
_lowrank.py
_namedtensor_internals.py
_ops.py
_six.py
_storage_docs.py
_tensor_docs.py [resubmit] Add amax/amin (#43819) 2020-08-31 04:54:48 -07:00
_tensor_str.py
_torch_docs.py [resubmit] Add amax/amin (#43819) 2020-08-31 04:54:48 -07:00
_utils.py [caffe2][torch] correctly re-raise Manifold StorageException 2020-08-28 11:41:10 -07:00
_utils_internal.py
_VF.py Address JIT/Mypy issue with torch._VF (#43454) 2020-08-25 09:23:54 -07:00
_vmap_internals.py
abi-check.cpp
CMakeLists.txt refactor torch/cuda/nccl.h to remove direct dependency on NCCL in libtorch_python (#42687) 2020-08-19 20:16:53 -07:00
custom_class.h Adding a version serialization type to ConvPackedParam (#43086) 2020-08-28 15:41:30 -07:00
custom_class_detail.h
extension.h
functional.py Fix type annotation errors in torch.functional (#43446) 2020-08-26 08:27:59 -07:00
hub.py Fix torch.hub for new zipfile format. (#42333) 2020-08-20 14:54:02 -07:00
library.h Reimplement per-operator selective build (#39401) 2020-08-20 19:10:02 -07:00
overrides.py Add __complex__ (#43844) 2020-08-31 11:39:41 -07:00
py.typed
quasirandom.py
random.py
README.txt
script.h
serialization.py [jit] PyTorchStreamReader::getAllRecord should omit archive name prefix (#43317) 2020-08-21 10:39:57 -07:00
storage.py
tensor.py [quant] Create nn.quantized.dynamic.EmbeddingBag (#43088) 2020-08-21 11:45:02 -07:00
types.py

Note [TH abstraction violation]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

TH/THC provide some hpp headers, which are proper C++ headers rather than
C headers.  These headers serve double duty as *internal implementation
detail* headers, whose contents should largely not be used by external
clients.

Ideally, we would not install these headers at all; instead, you should
use public functions (in headers like `THTensor.h`, NOT `THTensor.hpp`)
to manipulate these structs.  However, there are a few places
in torch/csrc where we violate this abstraction.  They are marked with
a pointer to this note.  Each of those sites will have to be refactored
when we refactor the guts of THTensor and related structures.