diff --git a/test/distributed/_tensor/test_random_ops.py b/test/distributed/_tensor/test_random_ops.py index 6964b412537..378af6d391e 100644 --- a/test/distributed/_tensor/test_random_ops.py +++ b/test/distributed/_tensor/test_random_ops.py @@ -6,12 +6,18 @@ import itertools import torch import torch.distributed._functional_collectives as funcol import torch.distributed.tensor._random as random -from torch.distributed._tensor import DeviceMesh, DTensor +from torch.distributed._composable.fsdp import fully_shard +from torch.distributed._tensor import DeviceMesh, DTensor, init_device_mesh from torch.distributed._tensor._utils import compute_local_shape_and_global_offset from torch.distributed._tensor.api import distribute_tensor from torch.distributed._tensor.placement_types import Replicate, Shard from torch.distributed.distributed_c10d import broadcast_object_list -from torch.distributed.tensor._random import is_rng_supported_mesh, manual_seed +from torch.distributed.tensor._random import ( + is_rng_supported_mesh, + manual_seed, + TensorParallelRNGTracker, +) +from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, @@ -101,6 +107,103 @@ class DistTensorRandomOpTest(DTensorTestBase): with self.assertRaisesRegex(RuntimeError, "different seed values"): manual_seed(self.rank, device_mesh) + @with_comms + @skip_unless_torch_gpu + def test_tp_model_meta_init(self): + # initialize the 1-d device mesh for TP + tp_mesh = init_device_mesh(self.device_type, mesh_shape=(self.world_size,)) + + # model meta init + with torch.device("meta"): + model = torch.nn.Linear(self.world_size, self.world_size, bias=False) + self.assertEqual(model.weight.device, torch.device("meta")) + parallelize_module(model, tp_mesh, ColwiseParallel()) + if random._rng_tracker is not None: + random._rng_tracker.distribute_region_enabled = True + + self.assertEqual(model.weight.device, torch.device("meta")) + + # actual initialization + device = torch.device("cuda", torch.cuda.current_device()) + model.to_empty(device=device) + model.reset_parameters() + self.assertTrue( + random._rng_tracker is not None + and isinstance(random._rng_tracker, TensorParallelRNGTracker) + ) + self.assertEqual(model.weight.device, device) + assert isinstance(model.weight, DTensor) + + # gather all the shards to compare initialization results + WORLD = torch.distributed.group.WORLD + assert WORLD is not None + weight_local = model.weight.to_local() + weight_gather = funcol.all_gather_tensor( + weight_local, + gather_dim=0, + group=WORLD, + ) + + # verify the weights are initialized differently on all ranks + for other_rank in range(self.world_size): + if self.rank != other_rank: + self.assertNotEqual( + weight_local, + weight_gather[other_rank : other_rank + 1, :], + ) + + @with_comms + @skip_unless_torch_gpu + def test_fsdp_tp_model_meta_init(self): + # initialize the 2-d device mesh + global_mesh = init_device_mesh( + self.device_type, + mesh_shape=(self.world_size // 2, 2), + mesh_dim_names=("dp", "tp"), + ) + dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"] + + # model meta init + with torch.device("meta"): + model = torch.nn.Linear(self.world_size, self.world_size, bias=False) + self.assertEqual(model.weight.device, torch.device("meta")) + parallelize_module(model, tp_mesh, ColwiseParallel()) + if random._rng_tracker is not None: + random._rng_tracker.distribute_region_enabled = True + + fully_shard(model, mesh=dp_mesh) + self.assertEqual(model.weight.device, torch.device("meta")) + + # actual initialization + device = torch.device("cuda", torch.cuda.current_device()) + model.to_empty(device=device) + model.reset_parameters() + self.assertTrue( + random._rng_tracker is not None + and isinstance(random._rng_tracker, TensorParallelRNGTracker) + ) + self.assertEqual(model.weight.device, device) + assert isinstance(model.weight, DTensor) + + # gather all the shards to compare initialization results + WORLD = torch.distributed.group.WORLD + assert WORLD is not None + weight_local = model.weight.to_local() + weight_gather = funcol.all_gather_tensor( + weight_local, + gather_dim=0, + group=WORLD, + ) + + # verify the weights are initialized differently on all ranks + with self.assertRaisesRegex(AssertionError, "AssertionError not raised"): + for other_rank in range(self.world_size): + if self.rank != other_rank: + self.assertNotEqual( + weight_local, + weight_gather[other_rank : other_rank + 1, :], + ) + @with_comms @skip_unless_torch_gpu def test_deterministic_dropout_1d(self): diff --git a/torch/distributed/tensor/_dispatch.py b/torch/distributed/tensor/_dispatch.py index 4383918ca35..331471db457 100644 --- a/torch/distributed/tensor/_dispatch.py +++ b/torch/distributed/tensor/_dispatch.py @@ -401,9 +401,11 @@ class OpDispatcher: mesh, OpSchema( op_call, - pytree.tree_unflatten(args_schema, args_spec) - if args_spec - else tuple(args_schema), + ( + pytree.tree_unflatten(args_schema, args_spec) + if args_spec + else tuple(args_schema) + ), kwargs_schema, schema_info=runtime_schema_info, ), diff --git a/torch/distributed/tensor/_random.py b/torch/distributed/tensor/_random.py index db4b2832548..e3f5176dd2b 100644 --- a/torch/distributed/tensor/_random.py +++ b/torch/distributed/tensor/_random.py @@ -145,8 +145,8 @@ class _RNGStateTracker: return int(seed_tensor.item()) def set_seed(self, name: str, seed: int) -> None: - seed_tensor = torch.tensor([seed]).view(torch.uint8) - offset_tensor = torch.tensor([0]).view(torch.uint8) + seed_tensor = torch.tensor([seed], device="cpu").view(torch.uint8) + offset_tensor = torch.tensor([0], device="cpu").view(torch.uint8) self.rng_states[name] = torch.cat([seed_tensor, offset_tensor]) def _distribute_region(self, spec: DTensorSpec): @@ -208,7 +208,7 @@ class OffsetBasedRNGTracker(_RNGStateTracker): ) seed_tensor = (self.rng_states[name])[0:8] - offset_tensor = torch.tensor([offset]).view(torch.uint8) + offset_tensor = torch.tensor([offset], device="cpu").view(torch.uint8) self.rng_states[name] = torch.cat([seed_tensor, offset_tensor]) def _set_pre_op_offset(self, spec: DTensorSpec) -> None: @@ -343,7 +343,9 @@ class TensorParallelRNGTracker(_RNGStateTracker): def __init__(self, device_type: str = "cuda"): super().__init__(device_type) # copy the default RNG state - self.rng_states["tensor-parallel-rng"] = self._device_handle.get_rng_state() + self.rng_states["tensor-parallel-rng"] = self._device_handle.get_rng_state().to( + "cpu" + ) def _manual_seed( self,