# Copyright (c) Meta Platforms, Inc. and affiliates
Allow parallelize_module to get device_mesh from ambient context (#134247)
This PR adds support for calling `parallelize_module` from within a model definition, making the model a parallel one.
Calling `parallelize_module` is an alternative to maintaining a set of `ColumnWiseLinear`, `RowWiseLinear`, etc., while still being able to directly author a parallel model.
(The motivation for authoring a parallel model is that there may be other distributed operations that are not easily captured by any module; see the forward function below. Put differently, the purpose is to exploit the expressiveness of DTensor -- we need to create DTensors before calling ops on them, and having parallelized modules in the model is one way of creating them.)
For example:
```
class FeedForward(nn.Module):
    def __init__(self, config: TransformerArgs) -> None:
        super().__init__()
        w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
        w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)
        w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)
        self.w1 = parallelize_module(w1, parallelize_plan=ColwiseParallel())
        self.w2 = parallelize_module(w2, parallelize_plan=RowwiseParallel())
        self.w3 = parallelize_module(w3, parallelize_plan=ColwiseParallel())

    def forward(self, x: Tensor) -> Tensor:
        y: DTensor = self.w2(F.silu(self.w1(x)) * self.w3(x))
        # y is a DTensor with Partial placement; we can return it as is,
        # or convert it to Replicate -- there is modeling flexibility here:
        # return y.redistribute(Replicate())
        return y


with device_mesh:
    model = FeedForward(config)

# Now model is parallelized onto device_mesh
y = model(x)
```
The `device_mesh` actually used by `parallelize_module` is retrieved from the ambient context.
Calling `parallelize_module` from within the model hierarchy also avoids the need for *FQNs*, which the out-of-model annotation case requires (a comparison sketch follows below).
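For reference, a minimal sketch of the out-of-model annotation style for the same example (the FQN keys are simply the submodule names defined above):
```
# Out-of-model annotation: parallelize an already-constructed model via an FQN plan.
model = FeedForward(config)
model = parallelize_module(
    model,
    device_mesh,
    {"w1": ColwiseParallel(), "w2": RowwiseParallel(), "w3": ColwiseParallel()},
)
```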
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134247
Approved by: https://github.com/tianyu-l
import warnings
from fnmatch import fnmatch
from typing import Optional, Union
import torch
[TP] Refactor Parallel Style to make it more usable (#111160)
One thing we find challenging for users is that we do not want to expose the concepts of prepare_input and prepare_output, since there are too many function names to select from, which is quite confusing. On the other hand, colwise and rowwise parallel always need their input (output) and output (input) to be in a certain layout, so we can simplify the logic and make it more usable.
So we add three public attributes to ParallelStyle, and the logic looks like this:
```python
class ParallelStyle(ABC):
    """
    The parallel style the user wants the module or submodule to be parallelized with.
    We can add more in the future, but this seems sufficient for immediate needs.
    Users can extend this class to build their own parallel style with customized
    input/output preparations.
    """
    input_layouts: Union[Placement, Tuple[Placement]]
    output_layouts: Union[Placement, Tuple[Placement]]
    use_local: bool


class RowwiseParallel(ParallelStyle):
    """
    Partition the rows of a module. We assume the input to be a sharded DTensor
    and the output to be a replicated tensor.
    """
    def __init__(self):
        super().__init__(input_layouts=Shard(-1), output_layouts=Replicate(), use_local=True)


class ColwiseParallel(ParallelStyle):
    """
    Partition the columns of a module. We assume the input to be a replicated DTensor
    and the output to be a sharded DTensor.
    """
    def __init__(self):
        super().__init__(input_layouts=Replicate(), output_layouts=Shard(-1), use_local=True)


# For the case of sequence parallel, users just set a different input layout,
# Shard(0) or Shard(1), instead of Replicate() (see the sketch after this message).


class PrepareModuleInput(ParallelStyle):
    """
    Only used to specify the input distribution spec for a module.
    """
    def __init__(self):
        super().__init__(input_layouts=Shard(0), output_layouts=Replicate(), use_local=False)


class PrepareModuleOutput(ParallelStyle):
    """
    Only used to specify the output distribution spec for a module.
    """
    def __init__(self):
        super().__init__(input_layouts=Replicate(), output_layouts=Shard(0), use_local=True)


parallelize_plan = {
    "embedding": ColwiseParallel(output_layouts=Replicate()),
    "attn": PrepareModuleInput(),
    "attn.w1": ColwiseParallel(),
    "attn.w2": ColwiseParallel(),
    "attn.w3": ColwiseParallel(),
    "attn.wo": RowwiseParallel(),
}

parallelize_module(
    module=block,  # this can be a submodule or module
    device_mesh=mesh["tp"],
    parallelize_plan=parallelize_plan,
)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111160
Approved by: https://github.com/wanchaol
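A rough usage sketch of that sequence-parallel idea, assuming the public ``input_layouts``/``output_layouts`` keyword arguments of the released styles (the plan keys and module names here are illustrative, not from this PR):
```python
# Hypothetical plan: the attention block consumes a sequence-sharded activation
# (Shard(1) on the sequence dim) and produces one again, instead of replicating.
plan = {
    "attn.w1": ColwiseParallel(input_layouts=Shard(1)),
    "attn.wo": RowwiseParallel(output_layouts=Shard(1)),
}
parallelize_module(module=block, device_mesh=mesh["tp"], parallelize_plan=plan)
```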
import torch.nn as nn
from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
from torch.distributed.tensor.parallel._utils import _validate_tp_mesh_dim
from torch.distributed.tensor.parallel.style import ParallelStyle
__all__ = ["parallelize_module"]
def parallelize_module( # type: ignore[return]
module: nn.Module,
device_mesh: Optional[DeviceMesh] = None,
parallelize_plan: Optional[Union[ParallelStyle, dict[str, ParallelStyle]]] = None,
*,
src_data_rank: Optional[int] = 0,
) -> nn.Module:
"""
[TP] fully rewrite Tensor Parallel APIs (#114732)
This PR rewrites the Tensor Parallel implementation. Tensor Parallel APIs are
supposed to be a very thin wrapper around DTensor APIs, but the current
implementation got too messy and buggy. It is really hard to debug what
goes wrong when using it. It is crucially important for advanced users and
developers to understand the API and its implementation easily, without
going through all the different kinds of functions and utils, so that
they can trust what happens under the hood.
In particular this PR:
* Makes ParallelStyle a real contract API for parallelize_module to
take; each concrete ParallelStyle only needs to implement `_apply` to
apply the sharding to an nn.Module, and all unnecessary fields are removed. This
also enables easier ParallelStyle authoring going forward (a minimal
custom-style sketch follows this message).
* Keeps the ColwiseParallel and RowwiseParallel public interfaces, but
refactors them so that the parameter sharding and the input/output
handling live within the style itself, making it easy to
understand how Linear/Embedding layers are sharded and how the input/output
transformations are performed.
* Removes the private _prepare_input/_prepare_output_fn fields for
both ColwiseParallel and RowwiseParallel. Since we have thrown deprecation
messages in nightly for a while, TP is a prototype release, and the
fields are private, it should be safe to remove them.
* Refactors the recently landed PrepareModuleInput/Output styles: renames
output_layouts to desired_input/output_layouts, groups
the functions inside the style itself, and removes default arguments for these
two styles so users need to specify them and think about the sharding
layouts. Fixes bugs about not handling the
`use_local_output` flag.
* Makes default arguments None instead of Placement objects; it is
standard Python practice not to use a custom object instance as a default
argument.
* Removes all dead APIs (i.e. the PairwiseParallel and SequenceParallel
styles and all prepare input/output functions), as we have thrown deprecation
messages for a while, and is in the process of removing all of them from the tests.
* Throws a deprecation warning for `tp_mesh_dim`, as we recommend using device
mesh slicing/indexing instead of manually specifying the mesh dim.
* Rewrites the documentation for every ParallelStyle to make it
clearer what each style is doing.
TODOs:
* Rewrite TP tests to adjust for the changes in this PR
* Add more tests to guard the bug fixes
Differential Revision: [D51761183](https://our.internmc.facebook.com/intern/diff/D51761183)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114732
Approved by: https://github.com/wz337, https://github.com/fduwjj
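As a rough illustration of that `_apply` contract (a sketch only; the style below is hypothetical and not part of this PR), a custom ParallelStyle might look like:
```python
from torch.distributed.tensor import Replicate, distribute_tensor


class ReplicateParallel(ParallelStyle):
    """Hypothetical style: replicate every parameter of a module across the TP mesh."""

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        for name, param in list(module.named_parameters(recurse=False)):
            # Turn the plain parameter into a replicated DTensor parameter.
            dtensor_param = nn.Parameter(distribute_tensor(param, device_mesh, [Replicate()]))
            module.register_parameter(name, dtensor_param)
        return module
```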
Apply Tensor Parallelism in PyTorch by parallelizing modules or sub-modules based on a user-specified plan.
We parallelize the module or sub-modules based on a parallelize_plan. The
parallelize_plan contains :class:`ParallelStyle`, which indicates how the user
wants the module or sub-module to be parallelized.

Users can also specify a different parallel style per module fully qualified name (FQN).

Note that ``parallelize_module`` only accepts a 1-D :class:`DeviceMesh`. If you have a
2-D or N-D :class:`DeviceMesh`, slice it into a 1-D sub-DeviceMesh first and then pass
that to this API (i.e. ``device_mesh["tp"]``).
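For example, with a 2-D mesh (a sketch; the mesh shape and dim names are illustrative)::

>>> mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))
>>> tp_mesh = mesh_2d["tp"]  # 1-D sub-mesh to pass to this API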
Args:
module (:class:`nn.Module`):
Module to be parallelized.
device_mesh (:class:`DeviceMesh`, optional):
Object which describes the mesh topology of devices for the DTensor.
If not specified, the call must be under a DeviceMesh context.
parallelize_plan (Union[:class:`ParallelStyle`, Dict[str, :class:`ParallelStyle`]], optional):
The plan used to parallelize the module. It can be either a
:class:`ParallelStyle` object which describes how we prepare
input/output for Tensor Parallelism, or it can be a dict of module
FQNs and their corresponding :class:`ParallelStyle` objects. If not
specified, the call will do nothing at the moment.
Keyword args:
src_data_rank (int, optional): the rank of the source data for the logical/global tensor; it is used by
:meth:`distribute_tensor` to scatter/broadcast the shards/replicas to other ranks. By default,
we use ``group_rank=0`` on each DeviceMesh dimension as the source data, to preserve the single-device
semantics. If ``None`` is passed explicitly, :meth:`parallelize_module` simply uses its local data instead
of trying to preserve the single-device semantics via scatter/broadcast. Default: 0
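For example, to skip the initial scatter/broadcast when every rank will load its own shard of
the weights afterwards (a sketch; ``plan`` stands for a parallelize plan as in the example below)::

>>> m = parallelize_module(m, tp_mesh, plan, src_data_rank=None)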
Return:
A :class:`nn.Module` object parallelized.
Example::
>>> # xdoctest: +SKIP("distributed")
>>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel, RowwiseParallel
>>> from torch.distributed.device_mesh import init_device_mesh
>>>
>>> # Define the module.
>>> m = Model(...)
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>> m = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel(), "w2": RowwiseParallel()})
>>>
.. note:: For complex module architectures like Attention and MLP layers, we recommend composing
different ParallelStyles together (i.e. ``ColwiseParallel`` and ``RowwiseParallel``) and passing
them as a parallelize_plan to achieve the desired sharding computation.
"""
torch._C._log_api_usage_once("torch.distributed.tensor.parallel.parallelize_module")
device_mesh = device_mesh or _mesh_resources.get_current_mesh()
_validate_tp_mesh_dim(device_mesh)
if parallelize_plan is None:
warnings.warn(
"No parallelize_plan is provided and auto-parallel is not supported "
"at the moment, so this parallelize_module call will do nothing."
)
return module
# note: The RNG tracker will be initialized in distribute_tensor() call if it hasn't
# been initialized.
if isinstance(parallelize_plan, ParallelStyle):
parallelize_plan.src_data_rank = src_data_rank
return parallelize_plan._apply(module, device_mesh)
elif isinstance(parallelize_plan, dict):
for module_path, parallelize_style in parallelize_plan.items():
path_splits = module_path.split(".")
if len(path_splits) == 0:
raise ValueError(
"Expect module path to be non-empty, but got empty string!"
)
while path_splits:
atom = path_splits.pop(0)
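# Each path component is matched against child names with fnmatch-style
# wildcards, so a plan key like "layers.*.attn" (hypothetical FQN) applies
# the style to every matching submodule.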
matched_children = filter(
# `t[0]` is child name
lambda t: fnmatch(t[0], atom),
module.named_children(),
)
# apply the plan to all matched submodules
for _, submodule in matched_children:
if path_splits:
# we haven't reached the leaf, apply in dict style
leaf_path = ".".join(
path_splits
) # rest of the path after `atom`
parallelize_module(
submodule,
device_mesh,
{leaf_path: parallelize_style},
src_data_rank=src_data_rank,
)
else:
# otherwise, directly apply style to this submodule
parallelize_module(
submodule,
device_mesh,
parallelize_style,
src_data_rank=src_data_rank,
)
return module
else:
raise TypeError( # pyre-ignore[7]
"Expect Union[ParallelStyle, Dict[str, ParallelStyle]] for"
f" parallelize_plan, {type(parallelize_plan)} found!"
)