diff --git a/test/distributed/_tensor/test_random_ops.py b/test/distributed/_tensor/test_random_ops.py
index 7c86af5ed96..8d9af07d953 100644
--- a/test/distributed/_tensor/test_random_ops.py
+++ b/test/distributed/_tensor/test_random_ops.py
@@ -17,6 +17,7 @@ from torch.distributed.tensor._random import (
     manual_seed,
     OffsetBasedRNGTracker,
 )
+from torch.distributed.tensor.debug import CommDebugMode
 from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module
 from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.distributed._tensor.common_dtensor import (
@@ -102,10 +103,80 @@ class DistTensorRandomOpTest(DTensorTestBase):
     @skip_unless_torch_gpu
     def test_manual_seed(self):
         device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
-        manual_seed(1234, device_mesh)
-        self.assertEqual(1234, random._rng_tracker.get_seed("parallel-rng"))
-        with self.assertRaisesRegex(RuntimeError, "different seed values"):
+
+        # in the case of calling ``torch.distributed.tensor._random.manual_seed``,
+        # no seed synchronization should happen since we fully trust the users' input
+        # and will not override the value.
+        comm_mode = CommDebugMode()
+        with comm_mode:
+            # Test 1: set different seed on different ranks
+            # RNG tracker should not be initialized until DTensor ``manual_seed``
+            # is called.
+            self.assertTrue(random._rng_tracker is None)
             manual_seed(self.rank, device_mesh)
+            # RNG tracker should already be initialized
+            self.assertTrue(random._rng_tracker is not None)
+            self.assertEqual(self.rank, random._rng_tracker.get_seed("parallel-rng"))
+
+            # Test 2: set same seed on different ranks
+            manual_seed(1234, device_mesh)
+            self.assertEqual(1234, random._rng_tracker.get_seed("parallel-rng"))
+
+        self.assertEqual(comm_mode.get_total_counts(), 0)
+
+    @with_comms
+    @skip_unless_torch_gpu
+    def test_manual_seed_submesh(self):
+        # the current rank is not a part of the mesh
+        single_rank_device_mesh = DeviceMesh(
+            self.device_type, [(self.rank + 1) % self.world_size]
+        )
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "manual_seed requires the current rank to be a part of the device mesh",
+        ):
+            manual_seed(self.rank, single_rank_device_mesh)
+
+    @with_comms
+    @skip_unless_torch_gpu
+    def test_pipeline_parallel_manual_seed(self):
+        # This test is to verify the `manual_seed` API works as expected in the
+        # pipeline parallel setting.
+        world_mesh = init_device_mesh(
+            self.device_type,
+            (self.world_size // 2, 2),
+            mesh_dim_names=("pp", "spmd"),
+        )
+        pp_mesh = world_mesh["pp"]
+        pp_rank = pp_mesh.get_local_rank()  # rank 0,1 = 0; rank 2,3 = 1
+        spmd_mesh = world_mesh["spmd"]
+
+        # set the seed for each pipeline stage to 123 + pp_rank
+        manual_seed(123 + pp_rank, spmd_mesh)
+        self.assertEqual(123 + pp_rank, random._rng_tracker.get_seed("parallel-rng"))
+
+        # mimic initializing a model weight sharded on the SPMD mesh
+        spmd_dtensor = torch.distributed.tensor.ones(
+            2 * spmd_mesh.size(), 2, device_mesh=spmd_mesh, placements=[Shard(0)]
+        )
+        torch.nn.init.normal_(spmd_dtensor)
+
+        # gather all the shards to compare initialization results
+        WORLD = torch.distributed.group.WORLD
+        assert WORLD is not None
+        tensor_gather = funcol.all_gather_tensor(
+            spmd_dtensor.to_local(),
+            gather_dim=0,
+            group=WORLD,
+        )
+
+        # verify the weights are initialized differently on all ranks
+        for other_rank in range(self.world_size):
+            if self.rank != other_rank:
+                self.assertNotEqual(
+                    spmd_dtensor.to_local(),
+                    tensor_gather[2 * other_rank : 2 * (other_rank + 1), :],
+                )
 
     @with_comms
     @skip_unless_torch_gpu
diff --git a/torch/distributed/tensor/_random.py b/torch/distributed/tensor/_random.py
index 4b420d81779..402b0160282 100644
--- a/torch/distributed/tensor/_random.py
+++ b/torch/distributed/tensor/_random.py
@@ -52,17 +52,20 @@ def manual_seed(seed: int, device_mesh: DeviceMesh) -> None:
 
     Args:
         seed (int): The desired seed.
-        device_mesh (:class:`DeviceMesh`): The device mesh to set the seed.
+        device_mesh (:class:`DeviceMesh`): The device mesh to set the seed. It is
+            required that the ``device_mesh`` include the calling rank. This is
+            to ensure that the SPMD region maintains a synchronous RNG state, which
+            means no ranks should be initialized with values other than ``seed``.
 
     Returns:
         None
 
     .. warning::
-        When calling this function, :func:`manual_seed` must be called from all ranks of the
-        default ``ProcessGroup`` even if some ranks may not be a part of the ``device_mesh``,
-        with the same ``seed`` value.
+        :func:`manual_seed` does not check the ``seed`` value correctness. Users must
+        ensure on their own that the value passed in is the desired ``seed`` for ranks
+        within ``device_mesh``.
         If ``device_mesh`` is a sub-mesh and the calling rank is not a part of it,
-        ``manual_seed`` will not set its GPU device's generator seed.
+        ``manual_seed`` will throw an error.
         Current implementation only supports a GPU device mesh.
     """
     device_handle = _get_device_handle(device_mesh.device_type)
@@ -71,24 +74,23 @@ def manual_seed(seed: int, device_mesh: DeviceMesh) -> None:
             f"DTensor randomness only supports cuda/cuda-like device type, but got {device_mesh.device_type}"
         )
 
-    # allgather the seed over the default PG
-    object_list = [seed] * dist.get_world_size()
-    dist.all_gather_object(object_list, seed)
-    for rank, object in enumerate(object_list):
-        if seed != int(object):
-            raise RuntimeError(
-                f"calling manual_seed function over {device_mesh} but received different seed values on ranks:",
-                f"seed on rank {dist.get_rank()} is {seed}, and seed on rank {rank} is {object}!",
-            )
     # instantiate a RNG tracker if haven't. By default DTensor uses an
     # OffsetBasedRNGTracker to perform random operators.
     global _rng_tracker
     if not _rng_tracker:
-        _rng_tracker = OffsetBasedRNGTracker(device_mesh.device_type)
+        _rng_tracker = OffsetBasedRNGTracker(
+            device_mesh.device_type, run_state_sync=False
+        )
 
     # the current rank is in mesh
     if device_mesh.get_coordinate() is not None:
         _rng_tracker._manual_seed(seed)
+    else:
+        raise RuntimeError(
+            "manual_seed requires the current rank to be a part of the device mesh "
+            "otherwise DTensor RNG state on the rank will not be initialized and "
+            "the behavior of DTensor random ops is undefined."
+        )
 
 
 class _RNGStateTracker:
@@ -155,11 +157,13 @@ class OffsetBasedRNGTracker(_RNGStateTracker):
     random operators.
     """
 
-    def __init__(self, device_type: str = "cuda"):
+    def __init__(self, device_type: str = "cuda", run_state_sync: bool = True):
         super().__init__(device_type)
-        # synchronize RNG state using rank 0's current one
         rng_state = self._device_handle.get_rng_state().to(device_type)
-        dist.broadcast(rng_state, 0)
+        if run_state_sync:
+            # synchronize RNG state using rank 0's current one
+            dist.broadcast(rng_state, 0)
+
         self.rng_states["parallel-rng"] = rng_state.to("cpu")
 
     def _manual_seed(self, parallel_seed: int) -> None:
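
For context, a minimal usage sketch of the semantics this patch moves to: ``manual_seed`` is called only by ranks that belong to the target mesh, different pipeline stages may pass different seeds, and no cross-rank seed synchronization or validation happens. The launch setup (single node, 4 GPUs via torchrun), mesh shape, and tensor sizes below are illustrative assumptions and are not part of the patch:

    # illustrative sketch only; assumes `torchrun --nproc-per-node 4` on one node
    import torch
    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import ones, Shard
    from torch.distributed.tensor._random import manual_seed

    dist.init_process_group("nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

    # 2 pipeline stages ("pp"), each spanning a 2-rank SPMD sub-mesh ("spmd")
    world_mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("pp", "spmd"))
    pp_rank = world_mesh["pp"].get_local_rank()
    spmd_mesh = world_mesh["spmd"]

    # every rank belongs to its own "spmd" sub-mesh, so this call is valid;
    # with this patch, passing a mesh that excludes the calling rank raises
    # a RuntimeError instead of silently doing nothing
    manual_seed(123 + pp_rank, spmd_mesh)

    # DTensor random ops now draw from the per-stage seed, so the two stages
    # initialize weights differently while staying in sync within each stage
    weight = ones(8, 4, device_mesh=spmd_mesh, placements=[Shard(0)])
    torch.nn.init.normal_(weight)

    dist.destroy_process_group()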