[dtensor][random] add 1d and 2d model meta init tests (#141731)
**Summary** Added tests for model meta init on a 1-d mesh (TP) and a 2-d mesh (FSDP+TP). This exercises the issue where DTensor RNG failed to initialize weights differently across FSDP ranks.

**Test** `pytest test/distributed/_tensor/test_random_ops.py -s -k meta_init`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141731
Approved by: https://github.com/wconstab
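For orientation, this is the deferred-initialization flow the new tests exercise, reduced to a minimal single-process sketch (no device mesh, parallelize_module, or RNG tracker; the layer size and device choice here are illustrative, not from the patch):

import torch

# 1. build the model on the meta device: shapes and dtypes only, no storage
with torch.device("meta"):
    model = torch.nn.Linear(8, 8, bias=False)
assert model.weight.device == torch.device("meta")

# 2. materialize uninitialized storage on the real target device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to_empty(device=device)

# 3. run the real initialization; under DTensor this is the step where
#    per-rank RNG behavior (the subject of these tests) kicks in
model.reset_parameters()
assert model.weight.device.type == device.type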
parent 1a32daeb17
commit c55191f3a2
3 changed files with 116 additions and 9 deletions
test/distributed/_tensor/test_random_ops.py

@@ -6,12 +6,18 @@ import itertools
 import torch
 import torch.distributed._functional_collectives as funcol
 import torch.distributed.tensor._random as random
-from torch.distributed._tensor import DeviceMesh, DTensor
+from torch.distributed._composable.fsdp import fully_shard
+from torch.distributed._tensor import DeviceMesh, DTensor, init_device_mesh
 from torch.distributed._tensor._utils import compute_local_shape_and_global_offset
 from torch.distributed._tensor.api import distribute_tensor
 from torch.distributed._tensor.placement_types import Replicate, Shard
 from torch.distributed.distributed_c10d import broadcast_object_list
-from torch.distributed.tensor._random import is_rng_supported_mesh, manual_seed
+from torch.distributed.tensor._random import (
+    is_rng_supported_mesh,
+    manual_seed,
+    TensorParallelRNGTracker,
+)
 from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module
 from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     DTensorTestBase,
@@ -101,6 +107,103 @@ class DistTensorRandomOpTest(DTensorTestBase):
         with self.assertRaisesRegex(RuntimeError, "different seed values"):
             manual_seed(self.rank, device_mesh)
 
+    @with_comms
+    @skip_unless_torch_gpu
+    def test_tp_model_meta_init(self):
+        # initialize the 1-d device mesh for TP
+        tp_mesh = init_device_mesh(self.device_type, mesh_shape=(self.world_size,))
+
+        # model meta init
+        with torch.device("meta"):
+            model = torch.nn.Linear(self.world_size, self.world_size, bias=False)
+            self.assertEqual(model.weight.device, torch.device("meta"))
+            parallelize_module(model, tp_mesh, ColwiseParallel())
+            if random._rng_tracker is not None:
+                random._rng_tracker.distribute_region_enabled = True
+
+        self.assertEqual(model.weight.device, torch.device("meta"))
+
+        # actual initialization
+        device = torch.device("cuda", torch.cuda.current_device())
+        model.to_empty(device=device)
+        model.reset_parameters()
+        self.assertTrue(
+            random._rng_tracker is not None
+            and isinstance(random._rng_tracker, TensorParallelRNGTracker)
+        )
+        self.assertEqual(model.weight.device, device)
+        assert isinstance(model.weight, DTensor)
+
+        # gather all the shards to compare initialization results
+        WORLD = torch.distributed.group.WORLD
+        assert WORLD is not None
+        weight_local = model.weight.to_local()
+        weight_gather = funcol.all_gather_tensor(
+            weight_local,
+            gather_dim=0,
+            group=WORLD,
+        )
+
+        # verify the weights are initialized differently on all ranks
+        for other_rank in range(self.world_size):
+            if self.rank != other_rank:
+                self.assertNotEqual(
+                    weight_local,
+                    weight_gather[other_rank : other_rank + 1, :],
+                )
+
+    @with_comms
+    @skip_unless_torch_gpu
+    def test_fsdp_tp_model_meta_init(self):
+        # initialize the 2-d device mesh
+        global_mesh = init_device_mesh(
+            self.device_type,
+            mesh_shape=(self.world_size // 2, 2),
+            mesh_dim_names=("dp", "tp"),
+        )
+        dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"]
+
+        # model meta init
+        with torch.device("meta"):
+            model = torch.nn.Linear(self.world_size, self.world_size, bias=False)
+            self.assertEqual(model.weight.device, torch.device("meta"))
+            parallelize_module(model, tp_mesh, ColwiseParallel())
+            if random._rng_tracker is not None:
+                random._rng_tracker.distribute_region_enabled = True
+
+        fully_shard(model, mesh=dp_mesh)
+        self.assertEqual(model.weight.device, torch.device("meta"))
+
+        # actual initialization
+        device = torch.device("cuda", torch.cuda.current_device())
+        model.to_empty(device=device)
+        model.reset_parameters()
+        self.assertTrue(
+            random._rng_tracker is not None
+            and isinstance(random._rng_tracker, TensorParallelRNGTracker)
+        )
+        self.assertEqual(model.weight.device, device)
+        assert isinstance(model.weight, DTensor)
+
+        # gather all the shards to compare initialization results
+        WORLD = torch.distributed.group.WORLD
+        assert WORLD is not None
+        weight_local = model.weight.to_local()
+        weight_gather = funcol.all_gather_tensor(
+            weight_local,
+            gather_dim=0,
+            group=WORLD,
+        )
+
+        # verify the weights are initialized differently on all ranks
+        with self.assertRaisesRegex(AssertionError, "AssertionError not raised"):
+            for other_rank in range(self.world_size):
+                if self.rank != other_rank:
+                    self.assertNotEqual(
+                        weight_local,
+                        weight_gather[other_rank : other_rank + 1, :],
+                    )
+
     @with_comms
     @skip_unless_torch_gpu
     def test_deterministic_dropout_1d(self):
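A note on the final assertion in test_fsdp_tp_model_meta_init: it deliberately encodes the bug from the commit summary. Because DTensor RNG currently fails to vary initialization across FSDP ranks, at least one assertNotEqual fails, and the outer assertRaisesRegex expects exactly that failure. The regex text works because a unittest-style assertNotEqual over tensors is commonly implemented by asserting that assertEqual raises; when the two values are equal, the assertRaises context itself fails with the message "AssertionError not raised". A self-contained sketch of that mechanism (plain unittest; the helper name is illustrative, not PyTorch's actual implementation):

import unittest

class NotEqualDemo(unittest.TestCase):
    def assert_not_equal_sketch(self, x, y):
        # assertNotEqual the way test frameworks often build it:
        # expect assertEqual to raise when the values differ
        with self.assertRaises(AssertionError):
            self.assertEqual(x, y)

    def test_equal_values_surface_as_not_raised(self):
        # x == y: the inner assertEqual passes, so assertRaises itself
        # fails with the message "AssertionError not raised"
        with self.assertRaisesRegex(AssertionError, "AssertionError not raised"):
            self.assert_not_equal_sketch(1, 1)

if __name__ == "__main__":
    unittest.main()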
torch/distributed/tensor/_dispatch.py

@@ -401,9 +401,11 @@ class OpDispatcher:
             mesh,
             OpSchema(
                 op_call,
-                pytree.tree_unflatten(args_schema, args_spec)
-                if args_spec
-                else tuple(args_schema),
+                (
+                    pytree.tree_unflatten(args_schema, args_spec)
+                    if args_spec
+                    else tuple(args_schema)
+                ),
                 kwargs_schema,
                 schema_info=runtime_schema_info,
             ),
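This OpDispatcher hunk is formatting-only: the conditional expression passed to OpSchema is wrapped in parentheses (the kind of change an autoformatter produces), so dispatch behavior is unchanged.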
torch/distributed/tensor/_random.py

@@ -145,8 +145,8 @@ class _RNGStateTracker:
         return int(seed_tensor.item())
 
     def set_seed(self, name: str, seed: int) -> None:
-        seed_tensor = torch.tensor([seed]).view(torch.uint8)
-        offset_tensor = torch.tensor([0]).view(torch.uint8)
+        seed_tensor = torch.tensor([seed], device="cpu").view(torch.uint8)
+        offset_tensor = torch.tensor([0], device="cpu").view(torch.uint8)
         self.rng_states[name] = torch.cat([seed_tensor, offset_tensor])
 
     def _distribute_region(self, spec: DTensorSpec):
@@ -208,7 +208,7 @@ class OffsetBasedRNGTracker(_RNGStateTracker):
         )
 
         seed_tensor = (self.rng_states[name])[0:8]
-        offset_tensor = torch.tensor([offset]).view(torch.uint8)
+        offset_tensor = torch.tensor([offset], device="cpu").view(torch.uint8)
         self.rng_states[name] = torch.cat([seed_tensor, offset_tensor])
 
     def _set_pre_op_offset(self, spec: DTensorSpec) -> None:
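For context on the device="cpu" pinning in the two hunks above: set_seed packs a 64-bit seed and a 64-bit offset into a 16-byte uint8 state tensor, and the offset update concatenates a fresh offset tensor onto a slice of that stored state. If the small helper tensors pick up an ambient default device (the new tests, for instance, run module setup under with torch.device("meta")), the torch.cat can fail or the stored state can land off-CPU, so building them explicitly on CPU keeps the packed state device-deterministic. A standalone sketch of the packing (the function name is illustrative, not from the patch):

import torch

def pack_rng_state(seed: int, offset: int = 0) -> torch.Tensor:
    # an int64 scalar reinterpreted as 8 uint8 bytes, built explicitly on CPU
    seed_tensor = torch.tensor([seed], device="cpu").view(torch.uint8)
    offset_tensor = torch.tensor([offset], device="cpu").view(torch.uint8)
    return torch.cat([seed_tensor, offset_tensor])  # 16-byte packed state

state = pack_rng_state(1234)
# the first 8 bytes round-trip back to the original seed
assert int(state[0:8].view(torch.int64).item()) == 1234
assert state.device.type == "cpu" and state.numel() == 16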
@@ -343,7 +343,9 @@ class TensorParallelRNGTracker(_RNGStateTracker):
     def __init__(self, device_type: str = "cuda"):
         super().__init__(device_type)
         # copy the default RNG state
-        self.rng_states["tensor-parallel-rng"] = self._device_handle.get_rng_state()
+        self.rng_states["tensor-parallel-rng"] = self._device_handle.get_rng_state().to(
+            "cpu"
+        )
 
     def _manual_seed(
         self,
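The same reasoning appears to motivate the .to("cpu") above: whichever device the handle's get_rng_state() keeps its state on, the stored copy is normalized to CPU so that later torch.cat with the CPU seed/offset tensors from set_seed cannot hit a device mismatch. A trivial check with the default CPU generator (illustrative only):

import torch

# generator state is a uint8 tensor; normalizing to CPU keeps it
# concatenable with the CPU tensors built in set_seed
state = torch.get_rng_state().to("cpu")
assert state.dtype == torch.uint8 and state.device.type == "cpu"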