Loss parallel is the last piece needed to enable sequence parallelism. It provides efficient distributed cross-entropy computation when the input is sharded on the class dimension (in a classification problem with many classes). The implementation is a context manager, `loss_parallel`; once it is enabled, users can call `torch.nn.functional.cross_entropy` or `torch.nn.CrossEntropyLoss` directly without modifying other parts of their code. The rationale for going through these op replacements:

1. `nn.functional.cross_entropy` is the method OSS users commonly call for things like transformer training. To avoid changing user code, we want users to keep using this function for loss calculation if they already do.
2. `nn.functional.cross_entropy` boils down to `aten.log_softmax` and `aten.nll_loss_forward/backward`, and DTensor already supports those ops (#117723 #119255 #118917 #119256). They perform the computation with the input *replicated* on the class dimension.
3. However, when the input of this loss calculation is **sharded on the class dimension**, running the sharded computation efficiently requires running both `aten.log_softmax` and `aten.nll_loss_forward` with multiple all-reduce collectives **in the middle of** those aten ops. This is not possible by just overriding these two ops, so we need a way to **decompose** them into smaller ops so that collectives can run in between.
4. We explored the existing decompositions (#118950). They work, except that `log_softmax_backward` and `nll_loss_backward` combined are implemented inefficiently in aten, which triggers an additional expensive collective. Some users recently reported similar issues: https://github.com/pytorch/pytorch/issues/119261.
5. Therefore, for now we do our own decomposition inside a context manager specifically for sequence parallelism. Once there is a better decomposition in core, we can adopt it instead of reinventing the wheel here.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119877
Approved by: https://github.com/wanchaol

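To illustrate point 3 in the rationale above, here is a minimal sketch (not the actual DTensor implementation) of why collectives need to run in the middle of the decomposed ops: a numerically stable log-softmax over a class dimension sharded across ranks needs one all-reduce for the global max and one for the global sum of exponentials before the final subtraction.

```python
import torch
import torch.distributed as dist

def sharded_log_softmax(x_local: torch.Tensor) -> torch.Tensor:
    # x_local: this rank's shard of the logits, shape (N, C_local), where the
    # full class dimension C is split across the ranks of the process group.
    local_max = x_local.max(dim=-1, keepdim=True).values
    dist.all_reduce(local_max, op=dist.ReduceOp.MAX)  # global max over the class dim
    shifted = x_local - local_max
    sum_exp = shifted.exp().sum(dim=-1, keepdim=True)
    dist.all_reduce(sum_exp, op=dist.ReduceOp.SUM)    # global sum of exponentials
    return shifted - sum_exp.log()                    # this rank's shard of log_softmax
```
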
.. role:: hidden
    :class: hidden-section

Tensor Parallelism - torch.distributed.tensor.parallel
=======================================================

Tensor Parallelism (TP) is built on top of the PyTorch DistributedTensor
(`DTensor <https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/README.md>`__)
and provides different parallelism styles: Colwise and Rowwise Parallelism.

.. warning::
    Tensor Parallelism APIs are experimental and subject to change.

The entrypoint to parallelize your ``nn.Module`` using Tensor Parallelism is:

.. automodule:: torch.distributed.tensor.parallel

.. currentmodule:: torch.distributed.tensor.parallel

.. autofunction:: parallelize_module

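For example, a minimal sketch of calling ``parallelize_module`` (the toy module, the submodule
name ``w1``, and the 8-GPU mesh below are assumptions for illustration):

.. code-block:: python

    import torch.nn as nn
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

    class ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.w1 = nn.Linear(16, 32)

        def forward(self, x):
            return self.w1(x)

    # One process per GPU, e.g. launched with ``torchrun --nproc_per_node=8``.
    tp_mesh = init_device_mesh("cuda", (8,))

    # Column-wise shard the weight of ``w1`` across the 8 GPUs.
    model = parallelize_module(ToyModel(), tp_mesh, {"w1": ColwiseParallel()})
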
Tensor Parallelism supports the following parallel styles:

.. autoclass:: torch.distributed.tensor.parallel.ColwiseParallel
  :members:
  :undoc-members:

.. autoclass:: torch.distributed.tensor.parallel.RowwiseParallel
  :members:
  :undoc-members:

To simply configure the nn.Module's inputs and outputs with DTensor layouts
and perform necessary layout redistributions, without distributing the module
parameters to DTensors, the following ``ParallelStyle`` s can be used in
the ``parallelize_plan`` when calling ``parallelize_module``:

.. autoclass:: torch.distributed.tensor.parallel.PrepareModuleInput
  :members:
  :undoc-members:

.. autoclass:: torch.distributed.tensor.parallel.PrepareModuleOutput
  :members:
  :undoc-members:

.. note:: When using ``Shard(dim)`` as the input/output layouts for the above
  ``ParallelStyle`` s, we assume the input/output activation tensors are evenly sharded on
  the tensor dimension ``dim`` on the ``DeviceMesh`` that TP operates on. For instance,
  since ``RowwiseParallel`` accepts input that is sharded on the last dimension, it assumes
  the input tensor has already been evenly sharded on the last dimension. For unevenly
  sharded activation tensors, one could pass in DTensors directly to the partitioned modules,
  and use ``use_local_output=False`` to return DTensors after each ``ParallelStyle``, so that
  DTensor can track the uneven sharding information.

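For example, a minimal sketch of using ``PrepareModuleInput`` in a plan (reusing ``tp_mesh`` and
``model`` from the sketch above; the submodule name ``"attn"`` and the layouts are assumptions
for illustration):

.. code-block:: python

    from torch.distributed._tensor import Replicate, Shard
    from torch.distributed.tensor.parallel import PrepareModuleInput, parallelize_module

    # The input of the (assumed) ``attn`` submodule arrives sharded on dim 1 and is
    # redistributed to a replicated DTensor before the submodule runs.
    plan = {
        "attn": PrepareModuleInput(
            input_layouts=Shard(1),
            desired_input_layouts=Replicate(),
        ),
    }
    model = parallelize_module(model, tp_mesh, plan)
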
For models like Transformer, we recommend using ``ColwiseParallel``
and ``RowwiseParallel`` together in the ``parallelize_plan`` to achieve the desired
sharding for the entire model (i.e. Attention and MLP).

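A sketch of such a plan for one Transformer block (the submodule names, e.g. ``attention.wq``
and ``feed_forward.w1``, as well as ``transformer_block`` and ``tp_mesh``, are assumptions
about the model definition):

.. code-block:: python

    from torch.distributed.tensor.parallel import (
        ColwiseParallel,
        RowwiseParallel,
        parallelize_module,
    )

    # Column-wise shard the projections that fan out and row-wise shard the projections
    # back, so an all-reduce of activations is only needed at the RowwiseParallel outputs.
    layer_plan = {
        "attention.wq": ColwiseParallel(),
        "attention.wk": ColwiseParallel(),
        "attention.wv": ColwiseParallel(),
        "attention.wo": RowwiseParallel(),
        "feed_forward.w1": ColwiseParallel(),
        "feed_forward.w2": RowwiseParallel(),
    }
    transformer_block = parallelize_module(transformer_block, tp_mesh, layer_plan)
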
Parallelized cross-entropy loss computation (loss parallelism) is supported via the following context manager:

.. autofunction:: torch.distributed.tensor.parallel.loss_parallel

.. warning::
    The loss_parallel API is experimental and subject to change.

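A minimal usage sketch (assuming ``logits`` is a DTensor sharded on the class dimension,
e.g. the output of a ``ColwiseParallel`` final projection with ``use_local_output=False``,
and ``labels`` holds the target class indices):

.. code-block:: python

    import torch.nn.functional as F
    from torch.distributed.tensor.parallel import loss_parallel

    # Both the loss computation and the backward pass run under the context manager,
    # so the class-dim-sharded logits never need to be gathered on a single rank.
    with loss_parallel():
        loss = F.cross_entropy(logits, labels)
        loss.backward()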