pytorch/test/distributed/tensor/parallel/test_parallelize_api.py
Wanchao Liang 0431d47eaa [tp] propagate src_data_rank kwarg in TP API (#144005)
as titled, this PR propagates the src_data_rank in the TP API, so that
module level APIs could leverage the flexibility to choose
src_data_rank, and avoid the communication if it does not need to

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144005
Approved by: https://github.com/tianyu-l
ghstack dependencies: #143883
2025-01-02 05:35:52 +00:00

340 lines
12 KiB
Python

# Owner(s): ["oncall: distributed"]
from collections import OrderedDict
from copy import deepcopy
import torch
from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.parallel.api import parallelize_module
from torch.distributed.tensor.parallel.style import (
ColwiseParallel,
PrepareModuleInput,
PrepareModuleOutput,
RowwiseParallel,
)
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
MLPModule,
MLPStacked,
with_comms,
)
class DummyModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x):
return x
class TensorParallelAPITests(DTensorTestBase):
@property
def world_size(self):
gpu_num = torch.cuda.device_count()
return gpu_num if gpu_num % 2 == 0 and gpu_num > 4 else 4
def _compare_params(
self,
local_module,
dist_module,
rank0_only,
skip_rowwise_bias=False,
compare_grad=False,
):
replicate = [Replicate()]
for name, param in local_module.named_parameters():
dist_param = dist_module.get_parameter(name)
param = param.grad if compare_grad else param
dist_param = dist_param.grad if compare_grad else dist_param
if (
(not rank0_only)
or (self.rank == 0)
or (
name not in ["net2.bias"]
and not skip_rowwise_bias
or name not in ["bias", "net2.bias"]
)
):
self.assertEqual(
param,
dist_param.redistribute(
device_mesh=dist_param.device_mesh, placements=replicate
).to_local(),
f"{name} not equal between dist and non-dist",
)
def _compare_module(
self, local_module, dist_module, inp_size, rank0_only=True, rowwise=False
):
LR = 0.25 # the learning rate we use for testing
local_optim = torch.optim.SGD(local_module.parameters(), lr=LR)
dist_optim = torch.optim.SGD(dist_module.parameters(), lr=LR)
torch.manual_seed(0)
inp = torch.rand(*inp_size, device=self.device_type)
self._compare_params(local_module, dist_module, rank0_only)
# check forward correctness
local_output = local_module(inp)
inp = inp.chunk(self.world_size, dim=-1)[self.rank] if rowwise else inp
dist_output = dist_module(inp)
dist_output = (
dist_output.redistribute(dist_output.device_mesh, [Replicate()]).to_local()
if isinstance(dist_output, DTensor)
else dist_output
)
self.assertEqual(local_output, dist_output)
local_output.sum().backward()
dist_output.sum().backward()
# check backward and ensure gradients are same
self._compare_params(local_module, dist_module, rank0_only, rowwise, True)
local_optim.step()
dist_optim.step()
self._compare_params(local_module, dist_module, rank0_only, rowwise)
@with_comms
def test_parallelize_mlp_with_module_api(self):
inp_size = [12, 10]
model = MLPModule(self.device_type)
model_tp = deepcopy(model)
# Parallelize module.
device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
model_tp = parallelize_module(
model_tp,
device_mesh,
{
"net1": ColwiseParallel(output_layouts=Replicate()),
"net2": ColwiseParallel(output_layouts=Replicate()),
},
)
self._compare_module(model, model_tp, inp_size, rank0_only=False)
@with_comms
def test_parallelize_mlp_with_module_api_nested(self):
inp_size = [12, 10]
model = torch.nn.Sequential(
OrderedDict([("dummy_encoder", MLPModule(self.device_type))])
)
model_tp = deepcopy(model)
# Parallelize module.
device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
model_tp = parallelize_module(
model_tp,
device_mesh,
{
"dummy_encoder.net1": ColwiseParallel(output_layouts=Replicate()),
"dummy_encoder.net2": ColwiseParallel(output_layouts=Replicate()),
},
)
self._compare_module(model, model_tp, inp_size, rank0_only=False)
@with_comms
def test_linear_row_wise_parallel(self):
# test RowwiseParallel
inp_size = [9, 16]
rowwise = RowwiseParallel()
torch.manual_seed(5)
model = torch.nn.Linear(16, 10, device=self.device_type)
model_tp = deepcopy(model)
# parallelize model_tp
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
model_tp = parallelize_module(model_tp, device_mesh, rowwise)
# let each rank generate unique local input
torch.manual_seed(self.rank)
self._compare_module(model, model_tp, inp_size, rowwise=True)
@with_comms
def test_linear_col_wise_parallel(self):
# test ColwiseParallel
inp_size = [8, 10]
colwise = ColwiseParallel(output_layouts=Replicate())
torch.manual_seed(5)
model = torch.nn.Linear(10, 16, device=self.device_type)
model_tp = deepcopy(model)
# parallelize model_tp
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
model_tp = parallelize_module(model_tp, device_mesh, colwise)
self._compare_module(model, model_tp, inp_size)
@with_comms
def test_prepare_module_input(self):
module = DummyModule()
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
parallelize_module(
module,
device_mesh,
PrepareModuleInput(
input_layouts=Shard(0), desired_input_layouts=Replicate()
),
)
inp = torch.rand(5, 7, device=self.device_type)
output = module(inp).redistribute(device_mesh, [Shard(0)]).to_local()
self.assertEqual(inp, output)
@with_comms
def test_prepare_module_output(self):
module = DummyModule()
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
parallelize_module(
module,
device_mesh,
PrepareModuleOutput(
output_layouts=Replicate(), desired_output_layouts=Shard(0)
),
)
torch.manual_seed(15)
inp = torch.rand(16, 7, device=self.device_type)
dtensor = DTensor.from_local(inp, device_mesh, [Replicate()], run_check=False)
output = module(dtensor)
inp = dtensor.redistribute(device_mesh, [Shard(0)]).to_local()
self.assertEqual(inp, output)
@with_comms
def test_parallelize_module_with_star(self):
inp_size = [12, 10]
model = MLPModule(self.device_type)
device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
model_tp = deepcopy(model)
model_tp = parallelize_module(
model_tp,
device_mesh,
{
"net*": ColwiseParallel(output_layouts=Replicate()),
},
)
self._compare_module(model, model_tp, inp_size, rank0_only=False)
@with_comms
def test_parallelize_module_src_data_rank(self):
# set seed different for each rank
torch.manual_seed(self.rank)
model = MLPModule(self.device_type)
device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
comm_mode = CommDebugMode()
# test src_data_rank == 1
with comm_mode:
model_tp = deepcopy(model)
model_tp = parallelize_module(
model_tp,
device_mesh,
{
"net*": ColwiseParallel(output_layouts=Replicate()),
},
src_data_rank=1,
)
self.assertTrue(comm_mode.get_total_counts() > 0)
tp_full_params = [param.full_tensor() for param in model_tp.parameters()]
if self.rank == 1:
orig_model_params = list(model.parameters())
for idx, param in enumerate(tp_full_params):
self.assertEqual(param, orig_model_params[idx])
# test src_data_rank == None
model_tp_no_comm = deepcopy(model)
with comm_mode:
parallelize_module(
model_tp_no_comm,
device_mesh,
{
"net1": ColwiseParallel(),
"net2": RowwiseParallel(),
},
src_data_rank=None,
)
self.assertEqual(comm_mode.get_total_counts(), 0)
@with_comms
def test_parallelize_module_with_question(self):
inp_size = [12, 10]
model = MLPModule(self.device_type)
device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
model_tp = deepcopy(model)
model_tp = parallelize_module(
model_tp,
device_mesh,
{
"net?": ColwiseParallel(output_layouts=Replicate()),
},
)
self._compare_module(model, model_tp, inp_size, rank0_only=False)
@with_comms
def test_parallelize_module_with_digit(self):
inp_size = [12, 10]
model = MLPModule(self.device_type)
device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
model_tp = deepcopy(model)
model_tp = parallelize_module(
model_tp,
device_mesh,
{
"net[1-2]": ColwiseParallel(output_layouts=Replicate()),
},
)
self._compare_module(model, model_tp, inp_size, rank0_only=False)
@with_comms
def test_parallelize_module_multi_wildcard(self):
inp_size = [12, 10]
model = MLPStacked(self.device_type, n_layers=2)
device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
model_tp = deepcopy(model)
model_tp = parallelize_module(
model_tp,
device_mesh,
{
"layers.*.net[1]": ColwiseParallel(),
"layers.*.net[2]": RowwiseParallel(),
},
)
self._compare_module(model, model_tp, inp_size, rank0_only=False)
@with_comms
def test_under_devicemesh_context(self):
# test ColwiseParallel
inp_size = [8, 10]
colwise = ColwiseParallel(output_layouts=Replicate())
torch.manual_seed(5)
model = torch.nn.Linear(10, 16, device=self.device_type)
model_tp = deepcopy(model)
# Call parallelize_module under DeviceMesh context.
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
with device_mesh:
model_tp = parallelize_module(model_tp, parallelize_plan=colwise)
self._compare_module(model, model_tp, inp_size)
@with_comms
def test_empty_plan(self):
torch.manual_seed(5)
model = torch.nn.Linear(10, 16, device=self.device_type)
# Call parallelize_module with empty plan.
# Goal is not to crash.
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
parallelize_module(model, device_mesh)
if __name__ == "__main__":
run_tests()