pytorch/test/distributed/tensor/parallel/test_tp_examples.py
wz337 c1dc8bb858 [DTensor] Turn on foreach implementation of optimizer for DTensor by default (#123394)
Append DTensor to the optimizer `_foreach_supported_types` and turn on foreach implementation of optimizer for DTensor if not specified by the users.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123394
Approved by: https://github.com/wanchaol
2024-05-15 16:45:42 +00:00

395 lines
14 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import itertools
from copy import deepcopy
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed._tensor import (
DeviceMesh,
distribute_tensor,
DTensor,
Replicate,
Shard,
)
from torch.distributed._tensor.debug import CommDebugMode
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper,
CheckpointImpl,
)
from torch.distributed.tensor.parallel import (
ColwiseParallel,
loss_parallel,
parallelize_module,
RowwiseParallel,
)
from torch.distributed.tensor.parallel.input_reshard import input_reshard
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
MLPModule,
ModelArgs,
NUM_DEVICES,
skip_unless_torch_gpu,
Transformer,
with_comms,
)
# Alias for the `torch.ops.c10d_functional` op namespace. Its members
# (e.g. `all_reduce`, `all_gather_into_tensor`, `reduce_scatter_tensor`) are
# used below as keys into the counts returned by
# `CommDebugMode.get_comm_counts()`.
c10d_functional = torch.ops.c10d_functional
class DistTensorParallelExampleTest(DTensorTestBase):
    """End-to-end tensor-parallel example tests.

    Each test builds a reference model and a deep copy of it, parallelizes
    the copy over a device mesh, and checks that outputs, gradients and
    updated weights of the two models agree. Several tests also check the
    number of collectives recorded by ``CommDebugMode``.
    """

    def _check_module(self, m1, m2, check_grad=False):
        """Assert that every parameter of ``m2`` matches the one in ``m1``.

        Args:
            m1: Reference module; its parameters are looked up by name.
            m2: Module under test. Parameters (or gradients) that are
                ``DTensor`` instances are redistributed to ``Replicate()``
                and converted to local tensors before comparison.
            check_grad: When true, compare ``.grad`` of each parameter
                instead of the parameter values.
        """
        named_parameters = dict(m1.named_parameters())
        for name, param_m2 in m2.named_parameters():
            # assertIn reports the missing name on failure.
            self.assertIn(name, named_parameters)
            param_m1 = named_parameters[name]
            if check_grad:
                param_m2 = param_m2.grad
                param_m1 = param_m1.grad
            if isinstance(param_m2, DTensor):
                # Gather the full value so it can be compared against m1.
                param_m2 = param_m2.redistribute(
                    device_mesh=param_m2.device_mesh,
                    placements=[Replicate()],
                ).to_local()
            self.assertEqual(param_m2, param_m1)

    def _test_mlp_training_e2e(self, is_seq_parallel=False, recompute_activation=False):
        """Train an ``MLPModule`` for one SGD step, plain vs. tensor-parallel.

        Args:
            is_seq_parallel: When true, ``net1`` takes ``Shard(0)`` input and
                ``net2`` produces ``Shard(0)`` output, and each rank uses its
                own random seed for the input.
            recompute_activation: When true, wrap the parallelized model in a
                non-reentrant ``checkpoint_wrapper`` and ``input_reshard``.
        """
        inp_size = [8, 10]
        # Sequence parallel seeds per rank, so inputs differ across ranks;
        # otherwise every rank uses seed 0 and gets the same input.
        rng_seed = self.rank if is_seq_parallel else 0
        torch.manual_seed(rng_seed)
        inp = torch.rand(*inp_size, device=self.device_type)
        model = MLPModule(self.device_type)
        model_tp = deepcopy(model)

        # Ensure both models are initialized the same way.
        self._check_module(model, model_tp)

        # Shard module and initialize optimizer.
        LR = 0.25
        device_mesh = DeviceMesh(
            self.device_type,
            torch.arange(0, NUM_DEVICES),
        )
        parallelize_plan = {
            "net1": (
                ColwiseParallel(input_layouts=Shard(0))
                if is_seq_parallel
                else ColwiseParallel()
            ),
            "net2": (
                RowwiseParallel(output_layouts=Shard(0))
                if is_seq_parallel
                else RowwiseParallel()
            ),
        }
        model_tp = parallelize_module(model_tp, device_mesh, parallelize_plan)
        if recompute_activation:
            model_tp = input_reshard(
                checkpoint_wrapper(
                    model_tp, checkpoint_impl=CheckpointImpl.NO_REENTRANT
                ),
                device_mesh,
                None if is_seq_parallel else 0,
            )
        optim = torch.optim.SGD(model.parameters(), lr=LR)
        optim_tp = torch.optim.SGD(model_tp.parameters(), lr=LR)

        output = model(inp)
        output.sum().backward()

        # Record the collectives issued by the parallel forward and backward.
        comm_mode = CommDebugMode()
        with comm_mode:
            output_tp = model_tp(inp)
            output_tp.sum().backward()

        self.assertEqual(output, output_tp)
        if is_seq_parallel:
            self.assertEqual(
                comm_mode.get_comm_counts()[c10d_functional.all_gather_into_tensor], 2
            )
            self.assertEqual(
                comm_mode.get_comm_counts()[c10d_functional.reduce_scatter_tensor], 1
            )
        else:
            self.assertEqual(comm_mode.get_comm_counts()[c10d_functional.all_reduce], 1)

        if is_seq_parallel:
            # Sum gradients from different ranks, since inputs
            # are different across ranks for sequence parallel.
            dist.all_reduce(model.net1.weight.grad)
            dist.all_reduce(model.net1.bias.grad)
            dist.all_reduce(model.net2.weight.grad)
            dist.all_reduce(model.net2.bias.grad)

        # Ensure gradients are the same.
        self._check_module(model, model_tp, check_grad=True)

        optim.step()
        optim_tp.step()

        # Ensure model weights are still the same after the update.
        self._check_module(model, model_tp)

        inp = torch.rand(*inp_size, device=self.device_type)
        output = model(inp)
        output_tp = model_tp(inp)
        self.assertEqual(output, output_tp)

    def _test_mlp_inference(self, device_mesh):
        """Compare forward outputs of an ``MLPModule`` and its parallel copy.

        Args:
            device_mesh: Mesh over which the copy is parallelized.
        """
        inp_size = [8, 10]
        # Ensure all tp ranks have the same input.
        torch.manual_seed(0)
        inp = torch.rand(*inp_size, device=self.device_type)
        model = MLPModule(self.device_type)
        model_tp = deepcopy(model)

        # Ensure both models are initialized the same way.
        self._check_module(model, model_tp)

        # Shard module.
        parallelize_plan = {
            "net1": ColwiseParallel(),
            "net2": RowwiseParallel(),
        }
        model_tp = parallelize_module(model_tp, device_mesh, parallelize_plan)

        output = model(inp)
        output_tp = model_tp(inp)
        self.assertEqual(output, output_tp)

    @with_comms
    @parametrize("is_seq_parallel", [True, False])
    # TODO: need to revisit input_reshard API about why it failed multi-gpu tests.
    # @parametrize("recompute_activation", [True, False])
    @parametrize("recompute_activation", [False])
    def test_mlp_training(self, is_seq_parallel, recompute_activation):
        # Thin parametrized entry point around _test_mlp_training_e2e.
        self._test_mlp_training_e2e(
            is_seq_parallel=is_seq_parallel, recompute_activation=recompute_activation
        )

    @with_comms
    def test_mlp_inference(self):
        # Run the MLP forward comparison under torch.inference_mode().
        device_mesh = DeviceMesh(
            self.device_type,
            torch.arange(0, NUM_DEVICES),
        )
        with torch.inference_mode():
            self._test_mlp_inference(device_mesh)

    @with_comms
    @skip_unless_torch_gpu
    @parametrize("is_seq_parallel", [True, False])
    def test_transformer_training(self, is_seq_parallel=False):
        # Compares one Adam training step of a Transformer against its
        # parallelized copy, checking collective counts at every stage.

        # Step 1: Initialize single-gpu models and optimizers.

        # Disable dropout in the test since we cannot reproduce the same random
        # behaviors when comparing single-gpu models with multi-gpu models.
        model_args = ModelArgs(dropout_p=0.0)
        model = Transformer(model_args).to(device=self.device_type)
        model_tp = deepcopy(model)
        self._check_module(model, model_tp)

        # Step 2: Set up and execute the parallelize plan to shard the test model
        # onto the device mesh.
        device_mesh = DeviceMesh(self.device_type, torch.arange(0, NUM_DEVICES))
        model_tp = Transformer.parallelize(model_tp, device_mesh, is_seq_parallel)

        # Step 3: Run test by comparing outputs from single-gpu and multi-gpu models.
        LR = 0.25
        optim = torch.optim.Adam(model.parameters(), lr=LR)
        optim_tp = torch.optim.Adam(model_tp.parameters(), lr=LR)

        # Initialize input and make sure all ranks have the same input.
        inp_size = [8, 8]  # [batch_size, seq_len]
        if is_seq_parallel:
            # The sequence dimension must split evenly across ranks. A unittest
            # assertion is used so the check is not stripped under `python -O`.
            self.assertEqual(inp_size[1] % self.world_size, 0)

        torch.manual_seed(0)
        inp = torch.randint(model_args.vocab_size, inp_size, device=self.device_type)

        # Compare outputs on the same input.
        output = model(inp)
        with CommDebugMode() as comm_mode:
            output_tp = model_tp(inp)
        self.assertEqual(output, output_tp)
        if is_seq_parallel:
            self.assertDictEqual(
                comm_mode.get_comm_counts(),
                {
                    c10d_functional.reduce_scatter_tensor: 6,
                    c10d_functional.all_gather_into_tensor: 6,
                },
            )
        else:
            self.assertDictEqual(
                comm_mode.get_comm_counts(),
                {
                    c10d_functional.all_reduce: 6,
                    c10d_functional.all_gather_into_tensor: 1,
                },
            )

        # Ensure gradients are equal.
        output.sum().backward()
        with CommDebugMode() as comm_mode:
            output_tp.sum().backward()
        self._check_module(model, model_tp, check_grad=True)
        if is_seq_parallel:
            self.assertDictEqual(
                comm_mode.get_comm_counts(),
                {
                    c10d_functional.reduce_scatter_tensor: 5,
                    c10d_functional.all_gather_into_tensor: 6,
                },
            )
        else:
            self.assertDictEqual(
                comm_mode.get_comm_counts(),
                {
                    c10d_functional.all_reduce: 9,
                },
            )

        # Ensure model weights are still the same after update.
        optim.step()
        from torch.distributed._tensor.experimental import implicit_replication

        with implicit_replication():
            with CommDebugMode() as comm_mode:
                optim_tp.step()
        self._check_module(model, model_tp)
        if is_seq_parallel:
            self.assertDictEqual(
                comm_mode.get_comm_counts(),
                {
                    c10d_functional.all_reduce: 30,
                },
            )
        else:
            self.assertDictEqual(comm_mode.get_comm_counts(), {})

        # Compare outputs on another input.
        torch.manual_seed(11)
        inp = torch.randint(model_args.vocab_size, inp_size, device=self.device_type)
        output = model(inp)
        output_tp = model_tp(inp)
        self.assertEqual(output, output_tp)

    @with_comms
    def test_weight_tying(self):
        # Checks that an embedding and a linear layer can share one
        # parameter (and hence one gradient) after parallelization.
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                # Initialize different weights for embedding and fc.
                torch.manual_seed(1)
                self.embedding = torch.nn.Embedding(16, 8)
                torch.manual_seed(2)
                self.fc = torch.nn.Linear(8, 16)

            def forward(self, x):
                return self.fc(self.embedding(x))

        model = TestModule().to(self.device_type)
        parallelize_plan = {
            "embedding": ColwiseParallel(),
            "fc": RowwiseParallel(),
        }
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        parallelize_module(model, device_mesh, parallelize_plan)

        input_size = [5]
        torch.manual_seed(0)
        inp = torch.randint(16, input_size, device=self.device_type)

        # Without weight tying.
        self.assertNotEqual(
            model.embedding.weight.to_local(), model.fc.weight.to_local()
        )
        output = model(inp)
        output.sum().backward()
        self.assertNotEqual(
            model.embedding.weight.grad.to_local(), model.fc.weight.grad.to_local()
        )
        model.zero_grad()

        # With weight tying.
        model.fc.weight = model.embedding.weight

        self.assertEqual(model.embedding.weight, model.fc.weight)
        self.assertEqual(id(model.embedding.weight), id(model.fc.weight))
        output = model(inp)
        output.sum().backward()
        self.assertEqual(model.embedding.weight.grad, model.fc.weight.grad)
        self.assertEqual(id(model.embedding.weight.grad), id(model.fc.weight.grad))

    @with_comms
    def test_loss_parallel(self):
        # Compares F.cross_entropy on a sharded DTensor input under
        # loss_parallel() against the plain-tensor result, for every shard
        # dimension and reduction mode.
        device_mesh = self.build_device_mesh()
        comm_mode = CommDebugMode()

        channel_size, channel_dim = 16, 1
        test_setup = [
            (2, (8, channel_size), (8,)),  # calling aten.nll_loss_forward
            (3, (8, channel_size, 12), (8, 12)),  # calling aten.nll_loss2d_forward
        ]
        weight = torch.rand(channel_size, device=self.device_type)
        for input_ndim, input_size, target_size in test_setup:
            x = torch.rand(*input_size, device=self.device_type, requires_grad=True)
            target = torch.randint(channel_size, target_size, device=self.device_type)

            shard_dims = list(range(input_ndim))
            reductions = ["none", "mean", "sum"]
            for shard_dim, reduction in itertools.product(shard_dims, reductions):
                dist_x = distribute_tensor(x, device_mesh, [Shard(shard_dim)])
                y = F.cross_entropy(x, target, weight, reduction=reduction)
                with loss_parallel():
                    if shard_dim == channel_dim:
                        with comm_mode:
                            dist_y = F.cross_entropy(
                                dist_x, target, weight, reduction=reduction
                            )
                            self.assertEqual(comm_mode.get_total_counts(), 3)
                            self.assertEqual(
                                comm_mode.get_comm_counts()[c10d_functional.all_reduce],
                                3,
                            )
                            self.assertTrue(dist_y.placements[0].is_replicate())
                            self.assertEqual(dist_y.to_local(), y)

                        with comm_mode:
                            if reduction == "none":
                                y.sum().backward()
                                dist_y.sum().backward()
                            else:
                                y.backward()
                                dist_y.backward()
                            self.assertEqual(comm_mode.get_total_counts(), 0)
                            self.assertTrue(
                                dist_x.grad.placements[0].is_shard(shard_dim)
                            )
                            self.assertEqual(dist_x.grad.full_tensor(), x.grad)
                        # x is reused across iterations; clear its accumulated grad.
                        x.grad.zero_()
                    else:
                        # Sharding on any dim other than channel_dim is
                        # expected to be rejected under loss_parallel().
                        with self.assertRaisesRegex(
                            ValueError,
                            "loss_parallel",
                        ):
                            F.cross_entropy(dist_x, target, reduction=reduction)
# Expand the @parametrize-decorated tests on the class before the runner
# looks for test methods.
instantiate_parametrized_tests(DistTensorParallelExampleTest)

if __name__ == "__main__":
    run_tests()