# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]

import torch
import torch.distributed as dist

from torch.distributed._tensor import DeviceMesh, DTensor, Replicate
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    CheckpointImpl,
)
from torch.distributed.tensor.parallel import (
    PairwiseParallel,
    parallelize_module,
    SequenceParallel,
)
from torch.distributed.tensor.parallel.input_reshard import input_reshard
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    MLPModule,
    NUM_DEVICES,
    with_comms,
)


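# This test compares a tensor-parallel (optionally sequence-parallel) MLP,
# with and without activation recomputation, against a plain local reference
# model for forward, gradient, and post-optimizer-step parity.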
class DistTensorParallelExampleTest(DTensorTestBase):
    def _check_module(self, m1, m2, check_grad=False, rank0_only_params=None):
        rank0_only_params = [] if rank0_only_params is None else rank0_only_params
        named_parameters = dict(m1.named_parameters())
        for name, param_m2 in m2.named_parameters():
            # Parameters listed in rank0_only_params are only validated on rank 0.
            if self.rank != 0 and name in rank0_only_params:
                continue
            self.assertTrue(name in named_parameters)
            param_m1 = named_parameters[name]
            if check_grad:
                param_m2 = param_m2.grad
                param_m1 = param_m1.grad
            if isinstance(param_m2, DTensor):
                # Gather the sharded DTensor parameter into a full replica and
                # convert it to a plain local tensor before comparing it with
                # the reference model's parameter.
                replicate = [Replicate()]
                param_m2 = param_m2.redistribute(
                    device_mesh=param_m2.device_mesh, placements=replicate
                ).to_local()
            self.assertEqual(param_m2, param_m1)

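    # End-to-end check of a Megatron-style MLP: parallelize one copy of the
    # model, then verify forward outputs, gradients, and post-optimizer-step
    # weights against an unparallelized reference.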
    def _test_mlp_megatron_e2e(self, is_seq_parallel=False, recompute_activation=False):
        inp_size = [8, 10]
        # Seed identically on every tp rank so all ranks see the same input;
        # under sequence parallelism, seed by rank so each rank gets a
        # different input shard.
        rng_seed = self.rank if is_seq_parallel else 0
        torch.manual_seed(rng_seed)
        inp = torch.rand(*inp_size, device=self.device_type)
        model = MLPModule(self.device_type)
        model_tp = MLPModule(self.device_type)

        # Ensure both models are initialized the same way.
        self._check_module(model, model_tp)

        # Shard the module and initialize the optimizers.
        LR = 0.25
        device_mesh = DeviceMesh(
            self.device_type,
            torch.arange(0, NUM_DEVICES),
        )
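        # PairwiseParallel shards pairs of linear layers column-wise then
        # row-wise (Megatron-style tensor parallelism); SequenceParallel does
        # the same while additionally treating the module input as sharded
        # along the sequence dimension.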
        parallel_style = SequenceParallel() if is_seq_parallel else PairwiseParallel()
        model_tp = parallelize_module(model_tp, device_mesh, parallel_style)
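        # With recompute_activation, wrap the parallelized module in
        # non-reentrant activation checkpointing and let input_reshard
        # re-shard its input (on dim 0, unless sequence parallelism already
        # splits the input) to reduce memory held for recomputation.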
        if recompute_activation:
            model_tp = input_reshard(
                checkpoint_wrapper(
                    model_tp, checkpoint_impl=CheckpointImpl.NO_REENTRANT
                ),
                device_mesh,
                None if is_seq_parallel else 0,
            )
        optim = torch.optim.SGD(model.parameters(), lr=LR)
        optim_tp = torch.optim.SGD(model_tp.parameters(), lr=LR)

        output = model(inp)
        output_tp = model_tp(inp)
        self.assertEqual(output, output_tp)

        output.sum().backward()
        output_tp.sum().backward()

        if is_seq_parallel:
            # Sum gradients across ranks, since the inputs differ per rank
            # under sequence parallelism.
            dist.all_reduce(model.net1.weight.grad)
            dist.all_reduce(model.net1.bias.grad)
            dist.all_reduce(model.net2.weight.grad)
            dist.all_reduce(model.net2.bias.grad)

        # Ensure gradients are the same.
        self._check_module(model, model_tp, check_grad=True)

        optim.step()
        optim_tp.step()

        # Ensure model weights are still the same after the update.
        # Due to the trick we use for Partial aggregation, net2.bias is only
        # checked on rank 0.
        self._check_module(model, model_tp, rank0_only_params=["net2.bias"])

        # Run a second forward pass with fresh input to confirm the updated
        # models still match.
        inp = torch.rand(*inp_size, device=self.device_type)
        output = model(inp)
        output_tp = model_tp(inp)
        self.assertEqual(output, output_tp)

    @with_comms
    @parametrize("is_seq_parallel", [True, False])
    @parametrize("recompute_activation", [True, False])
    def test_mlp_megatron_e2e(self, is_seq_parallel, recompute_activation):
        self._test_mlp_megatron_e2e(
            is_seq_parallel=is_seq_parallel,
            recompute_activation=recompute_activation,
        )


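# Generate one concrete test per (is_seq_parallel, recompute_activation)
# combination declared via @parametrize above.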
instantiate_parametrized_tests(DistTensorParallelExampleTest)

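# Running this file directly hands control to run_tests(); distributed setup
# and teardown for each case are handled by the DTensorTestBase/with_comms
# harness imported above.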
if __name__ == "__main__":
    run_tests()