# Owner(s): ["oncall: distributed"]

from collections import OrderedDict
from copy import deepcopy

import torch
from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.parallel.api import parallelize_module
from torch.distributed.tensor.parallel.style import (
    ColwiseParallel,
    PrepareModuleInput,
    PrepareModuleOutput,
    RowwiseParallel,
)
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    MLPModule,
    MLPStacked,
    with_comms,
)
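

# DummyModule is an identity module: the PrepareModuleInput/Output tests
# below use it so that only the input/output layout transformations are
# exercised, with no parameter sharding in between.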
class DummyModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x):
        return x


class TensorParallelAPITests(DTensorTestBase):
    @property
    def world_size(self):
        # Use all GPUs when the machine has an even count greater than 4;
        # otherwise run with 4 ranks.
        gpu_num = torch.cuda.device_count()
        return gpu_num if gpu_num % 2 == 0 and gpu_num > 4 else 4

    def _compare_params(
        self,
        local_module,
        dist_module,
        rank0_only,
        skip_rowwise_bias=False,
        compare_grad=False,
    ):
        replicate = [Replicate()]
        for name, param in local_module.named_parameters():
            dist_param = dist_module.get_parameter(name)
            param = param.grad if compare_grad else param
            dist_param = dist_param.grad if compare_grad else dist_param
            # Compare on every rank unless rank0_only is set; in that case
            # non-zero ranks skip "net2.bias" (and also "bias" when
            # skip_rowwise_bias is set).
            if (
                (not rank0_only)
                or (self.rank == 0)
                or (
                    name not in ["net2.bias"]
                    and not skip_rowwise_bias
                    or name not in ["bias", "net2.bias"]
                )
            ):
                self.assertEqual(
                    param,
                    # Redistribute to Replicate so the sharded parameter can
                    # be compared against the full local parameter.
                    dist_param.redistribute(
                        device_mesh=dist_param.device_mesh, placements=replicate
                    ).to_local(),
                    f"{name} not equal between dist and non-dist",
                )

    def _compare_module(
        self, local_module, dist_module, inp_size, rank0_only=True, rowwise=False
    ):
        LR = 0.25  # the learning rate we use for testing
        local_optim = torch.optim.SGD(local_module.parameters(), lr=LR)
        dist_optim = torch.optim.SGD(dist_module.parameters(), lr=LR)
        torch.manual_seed(0)
        inp = torch.rand(*inp_size, device=self.device_type)
        self._compare_params(local_module, dist_module, rank0_only)

        # check forward correctness
        local_output = local_module(inp)
        # RowwiseParallel expects the input to be sharded on the last
        # dimension, so feed each rank its own chunk.
        inp = inp.chunk(self.world_size, dim=-1)[self.rank] if rowwise else inp
        dist_output = dist_module(inp)
        dist_output = (
            dist_output.redistribute(dist_output.device_mesh, [Replicate()]).to_local()
            if isinstance(dist_output, DTensor)
            else dist_output
        )
        self.assertEqual(local_output, dist_output)

        local_output.sum().backward()
        dist_output.sum().backward()

        # check backward and ensure the gradients are the same
        self._compare_params(local_module, dist_module, rank0_only, rowwise, True)

        local_optim.step()
        dist_optim.step()
        self._compare_params(local_module, dist_module, rank0_only, rowwise)
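
    # The tests below follow a common pattern: build a local model, deepcopy
    # it, shard the copy with parallelize_module, then use _compare_module to
    # check parameters, forward, backward, and an optimizer step against the
    # local baseline.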
    @with_comms
    def test_parallelize_mlp_with_module_api(self):
        inp_size = [12, 10]
        model = MLPModule(self.device_type)
        model_tp = deepcopy(model)

        # Parallelize module.
        device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        model_tp = parallelize_module(
            model_tp,
            device_mesh,
            {
                "net1": ColwiseParallel(output_layouts=Replicate()),
                "net2": ColwiseParallel(output_layouts=Replicate()),
            },
        )
        self._compare_module(model, model_tp, inp_size, rank0_only=False)

    @with_comms
    def test_parallelize_mlp_with_module_api_nested(self):
        inp_size = [12, 10]
        model = torch.nn.Sequential(
            OrderedDict([("dummy_encoder", MLPModule(self.device_type))])
        )
        model_tp = deepcopy(model)

        # Parallelize module.
        device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        model_tp = parallelize_module(
            model_tp,
            device_mesh,
            {
                "dummy_encoder.net1": ColwiseParallel(output_layouts=Replicate()),
                "dummy_encoder.net2": ColwiseParallel(output_layouts=Replicate()),
            },
        )
        self._compare_module(model, model_tp, inp_size, rank0_only=False)

    @with_comms
    def test_linear_row_wise_parallel(self):
        # test RowwiseParallel
        inp_size = [9, 16]
        rowwise = RowwiseParallel()

        torch.manual_seed(5)
        model = torch.nn.Linear(16, 10, device=self.device_type)
        model_tp = deepcopy(model)

        # parallelize model_tp
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        model_tp = parallelize_module(model_tp, device_mesh, rowwise)

        # let each rank generate unique local input
        torch.manual_seed(self.rank)
        self._compare_module(model, model_tp, inp_size, rowwise=True)

    @with_comms
    def test_linear_col_wise_parallel(self):
        # test ColwiseParallel
        inp_size = [8, 10]
        colwise = ColwiseParallel(output_layouts=Replicate())

        torch.manual_seed(5)
        model = torch.nn.Linear(10, 16, device=self.device_type)
        model_tp = deepcopy(model)

        # parallelize model_tp
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        model_tp = parallelize_module(model_tp, device_mesh, colwise)

        self._compare_module(model, model_tp, inp_size)
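
    # PrepareModuleInput annotates the module's raw tensor inputs with
    # input_layouts and redistributes them to desired_input_layouts before
    # forward runs. With Shard(0) -> Replicate() below, each rank's [5, 7]
    # local input is treated as one shard of a [5 * world_size, 7] global
    # tensor and allgathered into a replicated DTensor.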
    @with_comms
    def test_prepare_module_input(self):
        module = DummyModule()
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        parallelize_module(
            module,
            device_mesh,
            PrepareModuleInput(
                input_layouts=Shard(0), desired_input_layouts=Replicate()
            ),
        )
        inp = torch.rand(5, 7, device=self.device_type)
        output = module(inp).redistribute(device_mesh, [Shard(0)]).to_local()
        self.assertEqual(inp, output)
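
    # PrepareModuleOutput is the mirror image: the module's replicated
    # DTensor output is redistributed to desired_output_layouts=Shard(0),
    # so each rank ends up holding its own row shard of the [16, 7] tensor.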
    @with_comms
    def test_prepare_module_output(self):
        module = DummyModule()
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        parallelize_module(
            module,
            device_mesh,
            PrepareModuleOutput(
                output_layouts=Replicate(), desired_output_layouts=Shard(0)
            ),
        )
        torch.manual_seed(15)
        inp = torch.rand(16, 7, device=self.device_type)
        dtensor = DTensor.from_local(inp, device_mesh, [Replicate()], run_check=False)
        output = module(dtensor)
        inp = dtensor.redistribute(device_mesh, [Shard(0)]).to_local()
        self.assertEqual(inp, output)
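
    # The plan keys passed to parallelize_module are matched against module
    # FQNs with fnmatch-style wildcards, so "net*" below covers both net1
    # and net2; the following tests exercise "?" and "[seq]" patterns too.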
    @with_comms
    def test_parallelize_module_with_star(self):
        inp_size = [12, 10]
        model = MLPModule(self.device_type)
        device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        model_tp = deepcopy(model)
        model_tp = parallelize_module(
            model_tp,
            device_mesh,
            {
                "net*": ColwiseParallel(output_layouts=Replicate()),
            },
        )
        self._compare_module(model, model_tp, inp_size, rank0_only=False)
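
    # src_data_rank selects which rank's tensors are used as the source of
    # truth when sharding the module's parameters (rank 0 by default);
    # src_data_rank=None skips that data synchronization entirely, so no
    # collective communication should be recorded.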
    @with_comms
    def test_parallelize_module_src_data_rank(self):
        # set a different seed for each rank
        torch.manual_seed(self.rank)
        model = MLPModule(self.device_type)
        device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        comm_mode = CommDebugMode()

        # test src_data_rank == 1
        with comm_mode:
            model_tp = deepcopy(model)
            model_tp = parallelize_module(
                model_tp,
                device_mesh,
                {
                    "net*": ColwiseParallel(output_layouts=Replicate()),
                },
                src_data_rank=1,
            )

        self.assertTrue(comm_mode.get_total_counts() > 0)
        tp_full_params = [param.full_tensor() for param in model_tp.parameters()]
        if self.rank == 1:
            orig_model_params = list(model.parameters())
            for idx, param in enumerate(tp_full_params):
                self.assertEqual(param, orig_model_params[idx])

        # test src_data_rank == None
        model_tp_no_comm = deepcopy(model)
        with comm_mode:
            parallelize_module(
                model_tp_no_comm,
                device_mesh,
                {
                    "net1": ColwiseParallel(),
                    "net2": RowwiseParallel(),
                },
                src_data_rank=None,
            )
        self.assertEqual(comm_mode.get_total_counts(), 0)

    @with_comms
    def test_parallelize_module_with_question(self):
        inp_size = [12, 10]
        model = MLPModule(self.device_type)
        device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        model_tp = deepcopy(model)
        model_tp = parallelize_module(
            model_tp,
            device_mesh,
            {
                "net?": ColwiseParallel(output_layouts=Replicate()),
            },
        )
        self._compare_module(model, model_tp, inp_size, rank0_only=False)

    @with_comms
    def test_parallelize_module_with_digit(self):
        inp_size = [12, 10]
        model = MLPModule(self.device_type)
        device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        model_tp = deepcopy(model)
        model_tp = parallelize_module(
            model_tp,
            device_mesh,
            {
                "net[1-2]": ColwiseParallel(output_layouts=Replicate()),
            },
        )
        self._compare_module(model, model_tp, inp_size, rank0_only=False)
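
    # Wildcards can appear at any level of the FQN: "layers.*.net[1]" below
    # matches net1 inside every layer of the two-layer MLPStacked model.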
    @with_comms
    def test_parallelize_module_multi_wildcard(self):
        inp_size = [12, 10]
        model = MLPStacked(self.device_type, n_layers=2)
        device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        model_tp = deepcopy(model)
        model_tp = parallelize_module(
            model_tp,
            device_mesh,
            {
                "layers.*.net[1]": ColwiseParallel(),
                "layers.*.net[2]": RowwiseParallel(),
            },
        )
        self._compare_module(model, model_tp, inp_size, rank0_only=False)
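
    # parallelize_module can also pick up the DeviceMesh from the ambient
    # context (#134247), which lets a model definition parallelize its own
    # submodules without threading a mesh argument through.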
    @with_comms
    def test_under_devicemesh_context(self):
        # test ColwiseParallel
        inp_size = [8, 10]
        colwise = ColwiseParallel(output_layouts=Replicate())

        torch.manual_seed(5)
        model = torch.nn.Linear(10, 16, device=self.device_type)
        model_tp = deepcopy(model)

        # Call parallelize_module under the DeviceMesh context.
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        with device_mesh:
            model_tp = parallelize_module(model_tp, parallelize_plan=colwise)

        self._compare_module(model, model_tp, inp_size)

    @with_comms
    def test_empty_plan(self):
        torch.manual_seed(5)
        model = torch.nn.Linear(10, 16, device=self.device_type)

        # Call parallelize_module with an empty plan.
        # The goal is simply not to crash.
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        parallelize_module(model, device_mesh)


if __name__ == "__main__":
    run_tests()