[dtensor][random] allow user to manual_seed different seed on device mesh; only sync RNG state in WORLD when manual_seed has not been called (#141223)

**Summary**
This PR proposes 4 changes to DTensor RNG management:
1. DTensor allows users to eagerly initialize the RNG tracker by calling `torch.distributed.tensor._random.manual_seed`.
2. DTensor `manual_seed` no longer checks the integrity of the `seed` argument. Users are responsible for setting the same seed on all ranks within an SPMD group, but if there are multiple separate SPMD groups (e.g. across pipeline stages), users should set a _different_ seed for each SPMD group. For cases like Pipeline Parallel, users can set a different initial seed for each pipeline stage by calling
```
world_mesh = init_device_mesh(
    device_type="cuda",
    mesh_shape=(2, 2, 2),
    mesh_dim_names=("pp", "dp", "tp"),
)
pp_mesh = world_mesh["pp"]
pp_rank = pp_mesh.get_local_rank()
spmd_mesh = world_mesh["dp", "tp"]._flatten("spmd")  # this flattening is only needed if you need to call collective over this mesh
torch.distributed.tensor._random.manual_seed(123+pp_rank, spmd_mesh)
```

In other words, if users want to call `torch.distributed.tensor._random.manual_seed`, they will be responsible for passing in the right value and DTensor won't perform any checks on it. If the current rank is not a part of the mesh, `manual_seed` will throw an error (see item 4 below).

3. `OffsetBasedRNGTracker` still performs RNG state synchronization by broadcasting the RNG state on rank 0 to `WORLD`. However, calling `torch.distributed.tensor._random.manual_seed` is an exception. In this case, no broadcast will happen.

4. Enforce that the `manual_seed` call only accepts a "full mesh", i.e. the DTensor RNG state on every rank must be set through the call. This makes sure that no rank has its RNG state left uninitialized and that the SPMD ranks have their RNG states synchronized.

**Motivation**
tl;dr

1. Lazily initializing the DTensor RNG tracker causes a hang in non-SPMD code such as Pipeline Parallel.
2. Users may want to set different seeds on ranks in one device mesh.
3. We want to keep the old behavior if users prefer not curating the RNG state and want to have DTensor take care of it.

see detail in https://github.com/pytorch/pytorch/issues/140301

**Test**
`pytest test/distributed/_tensor/test_random_ops.py`
`pytest test/distributed/tensor/parallel/test_tp_random_state.py`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141223
Approved by: https://github.com/wconstab
ghstack dependencies: #141731, #141220
This commit is contained in:
Xilun Wu 2024-11-27 15:28:16 -08:00 committed by PyTorch MergeBot
parent 7f5bc9dd87
commit 93cbb287c2
2 changed files with 96 additions and 21 deletions

View file

@ -17,6 +17,7 @@ from torch.distributed.tensor._random import (
manual_seed,
OffsetBasedRNGTracker,
)
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
@ -102,10 +103,80 @@ class DistTensorRandomOpTest(DTensorTestBase):
@skip_unless_torch_gpu
def test_manual_seed(self):
device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
manual_seed(1234, device_mesh)
self.assertEqual(1234, random._rng_tracker.get_seed("parallel-rng"))
with self.assertRaisesRegex(RuntimeError, "different seed values"):
# in the case of calling ``torch.distributed.tensor._random.manual_seed``,
# no seed synchronization should happen since we fully trust the users' input
# and will not override the value.
comm_mode = CommDebugMode()
with comm_mode:
# Test 1: set different seed on different ranks
# RNG tracker should not be initialized until DTensor ``manual_seed``
# is called.
self.assertTrue(random._rng_tracker is None)
manual_seed(self.rank, device_mesh)
# RNG tracker should already be initialized
self.assertTrue(random._rng_tracker is not None)
self.assertEqual(self.rank, random._rng_tracker.get_seed("parallel-rng"))
# Test 2: set same seed on different ranks
manual_seed(1234, device_mesh)
self.assertEqual(1234, random._rng_tracker.get_seed("parallel-rng"))
self.assertEqual(comm_mode.get_total_counts(), 0)
@with_comms
@skip_unless_torch_gpu
def test_manual_seed_submesh(self):
# the current rank is not a part of the mesh
single_rank_device_mesh = DeviceMesh(
self.device_type, [(self.rank + 1) % self.world_size]
)
with self.assertRaisesRegex(
RuntimeError,
"manual_seed requires the current rank to be a part of the device mesh",
):
manual_seed(self.rank, single_rank_device_mesh)
@with_comms
@skip_unless_torch_gpu
def test_pipeline_parallel_manual_seed(self):
# This test is to verify the `manual_seed` API works as expected in the
# pipeline parallel setting.
world_mesh = init_device_mesh(
self.device_type,
(self.world_size // 2, 2),
mesh_dim_names=("pp", "spmd"),
)
pp_mesh = world_mesh["pp"]
pp_rank = pp_mesh.get_local_rank() # rank 0,1 = 0; rank 2,3 = 1
spmd_mesh = world_mesh["spmd"]
# set the seed for each pipeline stage to 123 + pp_rank
manual_seed(123 + pp_rank, spmd_mesh)
self.assertEqual(123 + pp_rank, random._rng_tracker.get_seed("parallel-rng"))
# mimic initializing a model weight sharded on the SPMD mesh
spmd_dtensor = torch.distributed.tensor.ones(
2 * spmd_mesh.size(), 2, device_mesh=spmd_mesh, placements=[Shard(0)]
)
torch.nn.init.normal_(spmd_dtensor)
# gather all the shards to compare initialization results
WORLD = torch.distributed.group.WORLD
assert WORLD is not None
tensor_gather = funcol.all_gather_tensor(
spmd_dtensor.to_local(),
gather_dim=0,
group=WORLD,
)
# verify the weights are initialized differently on all ranks
for other_rank in range(self.world_size):
if self.rank != other_rank:
self.assertNotEqual(
spmd_dtensor.to_local(),
tensor_gather[2 * other_rank : 2 * (other_rank + 1), :],
)
@with_comms
@skip_unless_torch_gpu

View file

@ -52,17 +52,20 @@ def manual_seed(seed: int, device_mesh: DeviceMesh) -> None:
Args:
seed (int): The desired seed.
device_mesh (:class:`DeviceMesh`): The device mesh to set the seed.
device_mesh (:class:`DeviceMesh`): The device mesh to set the seed. It is
required that the ``device_mesh`` include the calling rank. This is
to ensure that the SPMD region maintains a synchronous RNG state, which
means no ranks should be initialized with values other than ``seed``.
Returns:
None
.. warning::
When calling this function, :func:`manual_seed` must be called from all ranks of the
default ``ProcessGroup`` even if some ranks may not be a part of the ``device_mesh``,
with the same ``seed`` value.
:func:`manual_seed` does not check the ``seed`` value correctness. Users must
ensure on their own that the value passed in is the desired ``seed`` for ranks
within ``device_mesh``.
If ``device_mesh`` is a sub-mesh and the calling rank is not a part of it,
``manual_seed`` will not set its GPU device's generator seed.
``manual_seed`` will throw an error.
Current implementation only supports a GPU device mesh.
"""
device_handle = _get_device_handle(device_mesh.device_type)
@ -71,24 +74,23 @@ def manual_seed(seed: int, device_mesh: DeviceMesh) -> None:
f"DTensor randomness only supports cuda/cuda-like device type, but got {device_mesh.device_type}"
)
# allgather the seed over the default PG
object_list = [seed] * dist.get_world_size()
dist.all_gather_object(object_list, seed)
for rank, object in enumerate(object_list):
if seed != int(object):
raise RuntimeError(
f"calling manual_seed function over {device_mesh} but received different seed values on ranks:",
f"seed on rank {dist.get_rank()} is {seed}, and seed on rank {rank} is {object}!",
)
# instantiate a RNG tracker if haven't. By default DTensor uses an
# OffsetBasedRNGTracker to perform random operators.
global _rng_tracker
if not _rng_tracker:
_rng_tracker = OffsetBasedRNGTracker(device_mesh.device_type)
_rng_tracker = OffsetBasedRNGTracker(
device_mesh.device_type, run_state_sync=False
)
# the current rank is in mesh
if device_mesh.get_coordinate() is not None:
_rng_tracker._manual_seed(seed)
else:
raise RuntimeError(
"manual_seed requires the current rank to be a part of the device mesh "
"otherwise DTensor RNG state on the rank will not be initialized and "
"the behavior of DTensor random ops is undefined."
)
class _RNGStateTracker:
@ -155,11 +157,13 @@ class OffsetBasedRNGTracker(_RNGStateTracker):
random operators.
"""
def __init__(self, device_type: str = "cuda"):
def __init__(self, device_type: str = "cuda", run_state_sync: bool = True):
super().__init__(device_type)
# synchronize RNG state using rank 0's current one
rng_state = self._device_handle.get_rng_state().to(device_type)
dist.broadcast(rng_state, 0)
if run_state_sync:
# synchronize RNG state using rank 0's current one
dist.broadcast(rng_state, 0)
self.rng_states["parallel-rng"] = rng_state.to("cpu")
def _manual_seed(self, parallel_seed: int) -> None: