[dtensor][random] allow user to manual_seed different seed on device mesh; only sync RNG state in WORLD when manual_seed has not been called (#141223)
**Summary**
This PR proposes 4 changes to DTensor RNG management:
1. DTensor allows users to eagerly initialize the RNG tracker by calling `torch.distributed.tensor._random.manual_seed`.
2. DTensor `manual_seed` no longer checks the integrity of the `seed` argument. Users are responsible for setting the same seed on all ranks within an SPMD group; if there are multiple separate SPMD groups (e.g. across pipeline stages), users should set a _different_ seed for each SPMD group. For cases like Pipeline Parallel, users can set a different initial seed for each pipeline stage by calling:
```
world_mesh = init_device_mesh(
    device_type="cuda",
    mesh_shape=(2, 2, 2),
    mesh_dim_names=("pp", "dp", "tp"),
)
pp_mesh = world_mesh["pp"]
pp_rank = pp_mesh.get_local_rank()
# this flattening is only needed if you need to call collectives over this mesh
spmd_mesh = world_mesh["dp", "tp"]._flatten("spmd")
torch.distributed.tensor._random.manual_seed(123 + pp_rank, spmd_mesh)
```
In other words, if users choose to call `torch.distributed.tensor._random.manual_seed`, they are responsible for passing in the right value; DTensor will not perform any checks on it. If the current rank is not a part of the mesh, `manual_seed` raises an error (see change 4).
3. `OffsetBasedRNGTracker` still performs RNG state synchronization by broadcasting the RNG state on rank 0 to `WORLD`. However, a tracker created through `torch.distributed.tensor._random.manual_seed` is the exception: in that case, no broadcast happens.
4. Enforce that `manual_seed` only accepts a "full mesh", i.e. the DTensor RNG state on every rank must be set through the call. This ensures that no rank has its RNG state left uninitialized and that the SPMD ranks keep their RNG states synchronized. A minimal sketch of the resulting behavior follows this list.
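The sketch below is modeled on the new tests in this PR; it assumes a 4-rank GPU job whose process group is already initialized (e.g. via `torchrun`), and the mesh shapes and variable names are illustrative rather than taken from the PR:
```python
import torch.distributed as dist
import torch.distributed.tensor._random as dtensor_random
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.tensor._random import manual_seed

rank = dist.get_rank()
world_size = dist.get_world_size()
mesh = init_device_mesh("cuda", (world_size,))

# Change 1: `manual_seed` initializes the RNG tracker eagerly.
assert dtensor_random._rng_tracker is None
manual_seed(rank, mesh)  # change 2: per-rank seed values are not validated
assert dtensor_random._rng_tracker is not None

# Change 3: because the tracker was created through `manual_seed`,
# no rank-0 broadcast of the RNG state happened above.

# Change 4: a mesh that excludes the calling rank is rejected, so no rank
# is left with an uninitialized DTensor RNG state.
other_rank_mesh = DeviceMesh("cuda", [(rank + 1) % world_size])
try:
    manual_seed(0, other_rank_mesh)
except RuntimeError:
    pass  # "manual_seed requires the current rank to be a part of the device mesh ..."
```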
**Motivation**
tl;dr
1. Lazily initializing the DTensor RNG tracker causes a hang in non-SPMD code such as Pipeline Parallel: the lazy path broadcasts rank 0's RNG state over `WORLD`, but only the ranks that reach a DTensor random op enter the collective.
2. Users may want to set different seeds on the ranks within one device mesh.
3. We want to keep the old behavior for users who prefer not to curate the RNG state themselves and want DTensor to take care of it.

See details in https://github.com/pytorch/pytorch/issues/140301
**Test**
`pytest test/distributed/_tensor/test_random_ops.py`
`pytest test/distributed/tensor/parallel/test_tp_random_state.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141223
Approved by: https://github.com/wconstab
ghstack dependencies: #141731, #141220
This commit is contained in:
parent 7f5bc9dd87
commit 93cbb287c2

2 changed files with 96 additions and 21 deletions
```diff
--- a/test/distributed/_tensor/test_random_ops.py
+++ b/test/distributed/_tensor/test_random_ops.py
@@ -17,6 +17,7 @@ from torch.distributed.tensor._random import (
     manual_seed,
+    OffsetBasedRNGTracker,
 )
 from torch.distributed.tensor.debug import CommDebugMode
 from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module
 from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.distributed._tensor.common_dtensor import (
```
```diff
@@ -102,10 +103,80 @@ class DistTensorRandomOpTest(DTensorTestBase):
     @skip_unless_torch_gpu
     def test_manual_seed(self):
         device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
-        manual_seed(1234, device_mesh)
-        self.assertEqual(1234, random._rng_tracker.get_seed("parallel-rng"))
-        with self.assertRaisesRegex(RuntimeError, "different seed values"):
+        # in the case of calling ``torch.distributed.tensor._random.manual_seed``,
+        # no seed synchronization should happen since we fully trust the users' input
+        # and will not override the value.
+        comm_mode = CommDebugMode()
+        with comm_mode:
+            # Test 1: set different seed on different ranks
+            # RNG tracker should not be initialized until DTensor ``manual_seed``
+            # is called.
+            self.assertTrue(random._rng_tracker is None)
+            manual_seed(self.rank, device_mesh)
+            # RNG tracker should already be initialized
+            self.assertTrue(random._rng_tracker is not None)
+            self.assertEqual(self.rank, random._rng_tracker.get_seed("parallel-rng"))
+
+            # Test 2: set same seed on different ranks
+            manual_seed(1234, device_mesh)
+            self.assertEqual(1234, random._rng_tracker.get_seed("parallel-rng"))
+
+        self.assertEqual(comm_mode.get_total_counts(), 0)
+
+    @with_comms
+    @skip_unless_torch_gpu
+    def test_manual_seed_submesh(self):
+        # the current rank is not a part of the mesh
+        single_rank_device_mesh = DeviceMesh(
+            self.device_type, [(self.rank + 1) % self.world_size]
+        )
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "manual_seed requires the current rank to be a part of the device mesh",
+        ):
+            manual_seed(self.rank, single_rank_device_mesh)
+
+    @with_comms
+    @skip_unless_torch_gpu
+    def test_pipeline_parallel_manual_seed(self):
+        # This test is to verify the `manual_seed` API works as expected in the
+        # pipeline parallel setting.
+        world_mesh = init_device_mesh(
+            self.device_type,
+            (self.world_size // 2, 2),
+            mesh_dim_names=("pp", "spmd"),
+        )
+        pp_mesh = world_mesh["pp"]
+        pp_rank = pp_mesh.get_local_rank()  # rank 0,1 = 0; rank 2,3 = 1
+        spmd_mesh = world_mesh["spmd"]
+
+        # set the seed for each pipeline stage to 123 + pp_rank
+        manual_seed(123 + pp_rank, spmd_mesh)
+        self.assertEqual(123 + pp_rank, random._rng_tracker.get_seed("parallel-rng"))
+
+        # mimic initializing a model weight sharded on the SPMD mesh
+        spmd_dtensor = torch.distributed.tensor.ones(
+            2 * spmd_mesh.size(), 2, device_mesh=spmd_mesh, placements=[Shard(0)]
+        )
+        torch.nn.init.normal_(spmd_dtensor)
+
+        # gather all the shards to compare initialization results
+        WORLD = torch.distributed.group.WORLD
+        assert WORLD is not None
+        tensor_gather = funcol.all_gather_tensor(
+            spmd_dtensor.to_local(),
+            gather_dim=0,
+            group=WORLD,
+        )
+
+        # verify the weights are initialized differently on all ranks
+        for other_rank in range(self.world_size):
+            if self.rank != other_rank:
+                self.assertNotEqual(
+                    spmd_dtensor.to_local(),
+                    tensor_gather[2 * other_rank : 2 * (other_rank + 1), :],
+                )
+
     @with_comms
     @skip_unless_torch_gpu
```
```diff
--- a/torch/distributed/tensor/_random.py
+++ b/torch/distributed/tensor/_random.py
@@ -52,17 +52,20 @@ def manual_seed(seed: int, device_mesh: DeviceMesh) -> None:
 
     Args:
         seed (int): The desired seed.
-        device_mesh (:class:`DeviceMesh`): The device mesh to set the seed.
+        device_mesh (:class:`DeviceMesh`): The device mesh to set the seed. It is
+            required that the ``device_mesh`` include the calling rank. This is
+            to ensure that the SPMD region maintains a synchronous RNG state, which
+            means no ranks should be initialized with values other than ``seed``.
 
     Returns:
         None
 
     .. warning::
-        When calling this function, :func:`manual_seed` must be called from all ranks of the
-        default ``ProcessGroup`` even if some ranks may not be a part of the ``device_mesh``,
-        with the same ``seed`` value.
+        :func:`manual_seed` does not check the ``seed`` value correctness. Users must
+        ensure on their own that the value passed in is the desired ``seed`` for ranks
+        within ``device_mesh``.
         If ``device_mesh`` is a sub-mesh and the calling rank is not a part of it,
-        ``manual_seed`` will not set its GPU device's generator seed.
+        ``manual_seed`` will throw an error.
         Current implementation only supports a GPU device mesh.
     """
     device_handle = _get_device_handle(device_mesh.device_type)
```
```diff
@@ -71,24 +74,23 @@ def manual_seed(seed: int, device_mesh: DeviceMesh) -> None:
             f"DTensor randomness only supports cuda/cuda-like device type, but got {device_mesh.device_type}"
         )
 
-    # allgather the seed over the default PG
-    object_list = [seed] * dist.get_world_size()
-    dist.all_gather_object(object_list, seed)
-    for rank, object in enumerate(object_list):
-        if seed != int(object):
-            raise RuntimeError(
-                f"calling manual_seed function over {device_mesh} but received different seed values on ranks:",
-                f"seed on rank {dist.get_rank()} is {seed}, and seed on rank {rank} is {object}!",
-            )
     # instantiate a RNG tracker if haven't. By default DTensor uses an
     # OffsetBasedRNGTracker to perform random operators.
     global _rng_tracker
     if not _rng_tracker:
-        _rng_tracker = OffsetBasedRNGTracker(device_mesh.device_type)
+        _rng_tracker = OffsetBasedRNGTracker(
+            device_mesh.device_type, run_state_sync=False
+        )
 
     # the current rank is in mesh
     if device_mesh.get_coordinate() is not None:
         _rng_tracker._manual_seed(seed)
+    else:
+        raise RuntimeError(
+            "manual_seed requires the current rank to be a part of the device mesh "
+            "otherwise DTensor RNG state on the rank will not be initialized and "
+            "the behavior of DTensor random ops is undefined."
+        )
 
 
 class _RNGStateTracker:
```
```diff
@@ -155,11 +157,13 @@ class OffsetBasedRNGTracker(_RNGStateTracker):
     random operators.
     """
 
-    def __init__(self, device_type: str = "cuda"):
+    def __init__(self, device_type: str = "cuda", run_state_sync: bool = True):
         super().__init__(device_type)
-        # synchronize RNG state using rank 0's current one
         rng_state = self._device_handle.get_rng_state().to(device_type)
-        dist.broadcast(rng_state, 0)
+        if run_state_sync:
+            # synchronize RNG state using rank 0's current one
+            dist.broadcast(rng_state, 0)
+
         self.rng_states["parallel-rng"] = rng_state.to("cpu")
 
     def _manual_seed(self, parallel_seed: int) -> None:
```
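If `manual_seed` is never called, the old behavior is preserved: the first DTensor random op lazily constructs the tracker with the default `run_state_sync=True`, so rank 0's RNG state is still broadcast over `WORLD` (which is why every rank must reach that op). A minimal sketch of this default path, assuming a 4-rank SPMD job where all ranks execute the random op; the mesh shape and tensor sizes are illustrative:
```python
import torch
import torch.distributed.tensor._random as dtensor_random
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, ones

mesh = init_device_mesh("cuda", (4,))
assert dtensor_random._rng_tracker is None  # no manual_seed call: still lazy

# The first DTensor random op creates an OffsetBasedRNGTracker with
# run_state_sync=True, i.e. rank 0's RNG state is broadcast to WORLD;
# every rank must reach this point or the collective blocks.
weight = ones(8, 2, device_mesh=mesh, placements=[Shard(0)])
torch.nn.init.normal_(weight)

assert isinstance(dtensor_random._rng_tracker, dtensor_random.OffsetBasedRNGTracker)
```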