pytorch/torch/distributed/tensor/parallel/api.py

# Copyright (c) Meta Platforms, Inc. and affiliates
import warnings
from fnmatch import fnmatch
from typing import Optional, Union
import torch
import torch.nn as nn
from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
from torch.distributed.tensor.parallel._utils import _validate_tp_mesh_dim
from torch.distributed.tensor.parallel.style import ParallelStyle
__all__ = ["parallelize_module"]


def parallelize_module(  # type: ignore[return]
    module: nn.Module,
    device_mesh: Optional[DeviceMesh] = None,
    parallelize_plan: Optional[Union[ParallelStyle, dict[str, ParallelStyle]]] = None,
    *,
    src_data_rank: Optional[int] = 0,
) -> nn.Module:
    """
    Apply Tensor Parallelism in PyTorch by parallelizing modules or sub-modules based on a user-specified plan.

    We parallelize a module or its sub-modules based on a ``parallelize_plan``. The plan contains
    :class:`ParallelStyle` objects, which indicate how the user wants each module or sub-module
    to be parallelized.

    Users can also specify a different parallel style per module fully qualified name (FQN).

    Note that ``parallelize_module`` only accepts a 1-D :class:`DeviceMesh`. If you have a 2-D or N-D
    :class:`DeviceMesh`, slice it to a 1-D sub-mesh first and pass that to this API
    (e.g. ``device_mesh["tp"]``).

    Args:
        module (:class:`nn.Module`):
            Module to be parallelized.
        device_mesh (:class:`DeviceMesh`, optional):
            Object which describes the mesh topology of devices for the DTensor.
            If not specified, the call must be under a DeviceMesh context.
        parallelize_plan (Union[:class:`ParallelStyle`, Dict[str, :class:`ParallelStyle`]], optional):
            The plan used to parallelize the module. It can be either a
            :class:`ParallelStyle` object which contains how we prepare
            input/output for Tensor Parallelism or it can be a dict of module
            FQN and its corresponding :class:`ParallelStyle` object. If not
            specified, the call will do nothing at the moment.

    Keyword args:
        src_data_rank (int, optional): the rank of the source data for the logical/global tensor. It is used by
            :meth:`distribute_tensor` to scatter/broadcast the shards/replicas to other ranks. By default,
            we use ``group_rank=0`` on each DeviceMesh dimension as the source data to preserve the single-device
            semantic. If ``None`` is passed explicitly, :meth:`parallelize_module` simply uses its local data instead
            of trying to preserve the single-device semantic via scatter/broadcast. Default: 0

    Return:
        A :class:`nn.Module` object parallelized.

    Example::
        >>> # xdoctest: +SKIP("distributed")
        >>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel, RowwiseParallel
        >>> from torch.distributed.device_mesh import init_device_mesh
        >>>
        >>> # Define the module.
        >>> m = Model(...)
        >>> tp_mesh = init_device_mesh("cuda", (8,))
        >>> m = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel(), "w2": RowwiseParallel()})
        >>>

    .. note:: For complex module architectures like Attention or MLP layers, we recommend composing
        different ParallelStyles together (e.g. ``ColwiseParallel`` and ``RowwiseParallel``) and passing
        them as a parallelize_plan to achieve the desired sharding computation.
    """
    torch._C._log_api_usage_once("torch.distributed.tensor.parallel.parallelize_module")
    device_mesh = device_mesh or _mesh_resources.get_current_mesh()
    _validate_tp_mesh_dim(device_mesh)
    if parallelize_plan is None:
        warnings.warn(
            "No parallelize_plan is provided and auto-parallel is not supported "
            "at the moment, so this parallelize_module call will do nothing."
        )
        return module
    # note: The RNG tracker will be initialized in distribute_tensor() call if it hasn't
    # been initialized.

    if isinstance(parallelize_plan, ParallelStyle):
        parallelize_plan.src_data_rank = src_data_rank
        return parallelize_plan._apply(module, device_mesh)
    elif isinstance(parallelize_plan, dict):
        for module_path, parallelize_style in parallelize_plan.items():
            path_splits = module_path.split(".")
            # `str.split` never returns an empty list, so test the path itself
            if not module_path:
                raise ValueError(
                    "Expect module path to be non-empty, but got empty string!"
                )
            while path_splits:
                atom = path_splits.pop(0)
                matched_children = filter(
                    # `t[0]` is child name
                    lambda t: fnmatch(t[0], atom),
                    module.named_children(),
                )
                # apply the plan to all matched submodules
                for _, submodule in matched_children:
                    if path_splits:
                        # we haven't reached the leaf, apply in dict style
                        leaf_path = ".".join(
                            path_splits
                        )  # rest of the path after `atom`
                        parallelize_module(
                            submodule,
                            device_mesh,
                            {leaf_path: parallelize_style},
                            src_data_rank=src_data_rank,
                        )
                    else:
                        # otherwise, directly apply style to this submodule
                        parallelize_module(
                            submodule,
                            device_mesh,
                            parallelize_style,
                            src_data_rank=src_data_rank,
                        )
        return module
    else:
        raise TypeError(  # pyre-ignore[7]
            "Expect Union[ParallelStyle, Dict[str, ParallelStyle]] for"
            f" parallelize_plan, {type(parallelize_plan)} found!"
        )