[dtensor][random] add 1d and 2d model meta init tests (#141731)

**Summary**
Added tests for model meta init on a 1-d mesh (TP) and a 2-d mesh (FSDP+TP). These tests exercise the issue where DTensor RNG failed to initialize weights differently across FSDP ranks.

**Test**
`pytest test/distributed/_tensor/test_random_ops.py -s -k meta_init`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141731
Approved by: https://github.com/wconstab
This commit is contained in:
Xilun Wu 2024-11-27 15:28:15 -08:00 committed by PyTorch MergeBot
parent 1a32daeb17
commit c55191f3a2
3 changed files with 116 additions and 9 deletions

View file

@@ -6,12 +6,18 @@ import itertools
import torch
import torch.distributed._functional_collectives as funcol
import torch.distributed.tensor._random as random
from torch.distributed._tensor import DeviceMesh, DTensor
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed._tensor import DeviceMesh, DTensor, init_device_mesh
from torch.distributed._tensor._utils import compute_local_shape_and_global_offset
from torch.distributed._tensor.api import distribute_tensor
from torch.distributed._tensor.placement_types import Replicate, Shard
from torch.distributed.distributed_c10d import broadcast_object_list
from torch.distributed.tensor._random import is_rng_supported_mesh, manual_seed
from torch.distributed.tensor._random import (
is_rng_supported_mesh,
manual_seed,
TensorParallelRNGTracker,
)
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
@@ -101,6 +107,103 @@ class DistTensorRandomOpTest(DTensorTestBase):
with self.assertRaisesRegex(RuntimeError, "different seed values"):
manual_seed(self.rank, device_mesh)
@with_comms
@skip_unless_torch_gpu
def test_tp_model_meta_init(self):
    """Meta-device init on a 1-d (TP) mesh.

    Builds a Linear layer on the meta device, parallelizes it colwise,
    materializes it on GPU, and checks that every rank ends up with a
    differently-initialized weight shard.
    """
    # 1-d device mesh spanning the whole world, used for tensor parallelism
    mesh_1d = init_device_mesh(self.device_type, mesh_shape=(self.world_size,))

    # construct the module on the meta device -- parameters get no storage yet
    with torch.device("meta"):
        model = torch.nn.Linear(self.world_size, self.world_size, bias=False)
        self.assertEqual(model.weight.device, torch.device("meta"))
        parallelize_module(model, mesh_1d, ColwiseParallel())
        if random._rng_tracker is not None:
            random._rng_tracker.distribute_region_enabled = True
    self.assertEqual(model.weight.device, torch.device("meta"))

    # materialize the parameters on the current CUDA device and initialize them;
    # reset_parameters() is where the DTensor RNG tracker actually gets used
    init_device = torch.device("cuda", torch.cuda.current_device())
    model.to_empty(device=init_device)
    model.reset_parameters()
    self.assertTrue(
        random._rng_tracker is not None
        and isinstance(random._rng_tracker, TensorParallelRNGTracker)
    )
    self.assertEqual(model.weight.device, init_device)
    assert isinstance(model.weight, DTensor)

    # collect every rank's local shard so they can be cross-compared
    world_pg = torch.distributed.group.WORLD
    assert world_pg is not None
    local_shard = model.weight.to_local()
    gathered = funcol.all_gather_tensor(local_shard, gather_dim=0, group=world_pg)

    # no two ranks may hold an identical shard
    for other_rank in range(self.world_size):
        if self.rank == other_rank:
            continue
        self.assertNotEqual(
            local_shard,
            gathered[other_rank : other_rank + 1, :],
        )
@with_comms
@skip_unless_torch_gpu
# Meta-device init on a 2-d mesh: TP (colwise) over the "tp" dim, then FSDP
# (fully_shard) over the "dp" dim. Materializes a meta-constructed Linear
# and inspects how initialization compares across ranks.
def test_fsdp_tp_model_meta_init(self):
# initialize the 2-d device mesh
global_mesh = init_device_mesh(
self.device_type,
mesh_shape=(self.world_size // 2, 2),
mesh_dim_names=("dp", "tp"),
)
dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"]
# model meta init -- parameters have no real storage until to_empty()
with torch.device("meta"):
model = torch.nn.Linear(self.world_size, self.world_size, bias=False)
self.assertEqual(model.weight.device, torch.device("meta"))
# TP sharding first, then FSDP wrapping on the dp mesh dim
parallelize_module(model, tp_mesh, ColwiseParallel())
if random._rng_tracker is not None:
random._rng_tracker.distribute_region_enabled = True
fully_shard(model, mesh=dp_mesh)
self.assertEqual(model.weight.device, torch.device("meta"))
# actual initialization: allocate storage, then draw weights through the
# DTensor RNG tracker via reset_parameters()
device = torch.device("cuda", torch.cuda.current_device())
model.to_empty(device=device)
model.reset_parameters()
self.assertTrue(
random._rng_tracker is not None
and isinstance(random._rng_tracker, TensorParallelRNGTracker)
)
self.assertEqual(model.weight.device, device)
assert isinstance(model.weight, DTensor)
# gather all the shards to compare initialization results
WORLD = torch.distributed.group.WORLD
assert WORLD is not None
weight_local = model.weight.to_local()
weight_gather = funcol.all_gather_tensor(
weight_local,
gather_dim=0,
group=WORLD,
)
# verify the weights are initialized differently on all ranks
# NOTE(review): "AssertionError not raised" is the message unittest emits
# when the body of an assertRaises context raises nothing. As rendered here,
# the outer context would require the loop itself to raise that exact text,
# which assertNotEqual never produces -- a nested
# `with self.assertRaises(AssertionError):` may have been lost in this diff
# rendering. Confirm against the upstream commit.
with self.assertRaisesRegex(AssertionError, "AssertionError not raised"):
for other_rank in range(self.world_size):
if self.rank != other_rank:
self.assertNotEqual(
weight_local,
weight_gather[other_rank : other_rank + 1, :],
)
@with_comms
@skip_unless_torch_gpu
def test_deterministic_dropout_1d(self):

View file

@@ -401,9 +401,11 @@ class OpDispatcher:
mesh,
OpSchema(
op_call,
pytree.tree_unflatten(args_schema, args_spec)
if args_spec
else tuple(args_schema),
(
pytree.tree_unflatten(args_schema, args_spec)
if args_spec
else tuple(args_schema)
),
kwargs_schema,
schema_info=runtime_schema_info,
),

View file

@@ -145,8 +145,8 @@ class _RNGStateTracker:
return int(seed_tensor.item())
def set_seed(self, name: str, seed: int) -> None:
seed_tensor = torch.tensor([seed]).view(torch.uint8)
offset_tensor = torch.tensor([0]).view(torch.uint8)
seed_tensor = torch.tensor([seed], device="cpu").view(torch.uint8)
offset_tensor = torch.tensor([0], device="cpu").view(torch.uint8)
self.rng_states[name] = torch.cat([seed_tensor, offset_tensor])
def _distribute_region(self, spec: DTensorSpec):
@@ -208,7 +208,7 @@ class OffsetBasedRNGTracker(_RNGStateTracker):
)
seed_tensor = (self.rng_states[name])[0:8]
offset_tensor = torch.tensor([offset]).view(torch.uint8)
offset_tensor = torch.tensor([offset], device="cpu").view(torch.uint8)
self.rng_states[name] = torch.cat([seed_tensor, offset_tensor])
def _set_pre_op_offset(self, spec: DTensorSpec) -> None:
@@ -343,7 +343,9 @@ class TensorParallelRNGTracker(_RNGStateTracker):
def __init__(self, device_type: str = "cuda"):
super().__init__(device_type)
# copy the default RNG state
self.rng_states["tensor-parallel-rng"] = self._device_handle.get_rng_state()
self.rng_states["tensor-parallel-rng"] = self._device_handle.get_rng_state().to(
"cpu"
)
def _manual_seed(
self,