mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[SymmetricMemory] introduce user-facing APIs empty() and rendezvous() (#139677)
Previously `SymmetricMemory` only had private pybind APIs:
```python
from torch.distributed._symmetric_memory import _SymmetricMemory
t = _SymmetricMemory.empty_strided_p2p(
size=(64,),
stride=(1,),
dtype=torch.float32,
device=device,
)
symm_mem_hdl = _SymmetricMemory.rendezvous(t, group_name=group.group_name)
```
This PR introduces user-facing APIs empty() and rendezvous():
```python
import torch.distributed._symmetric_memory as symm_mem
t = symm_mem.empty(64, device="cuda")
symm_mem_hdl = symm_mem.rendezvous(t, group_name=group.group_name)
```
Notable differences compared to the pybind APIs:
- `empty()` now resembles `torch.empty()`:
- the shape can be given either as a single integer sequence or as variadic integer arguments (like `torch.empty()`)
- the stride can no longer be specified (and no longer needs to be)
- device can either be `torch.device` or string
- `group_name` needs to be specified at rendezvous time as opposed to allocation time. See https://github.com/pytorch/pytorch/pull/139529 for the rationales. I feel the new semantic is superior, hence enforcing it in the public API.
- Currently, the pybind API still supports specifying `group_name` at allocation time.
This PR does not change the behavior of the pybind APIs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139677
Approved by: https://github.com/lw
ghstack dependencies: #139529
This commit is contained in:
parent
9f4af6b4e6
commit
5a7e147ef3
2 changed files with 183 additions and 85 deletions
|
|
@ -5,6 +5,7 @@ from unittest import skipIf
|
|||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.distributed._symmetric_memory as symm_mem
|
||||
from torch._C._autograd import DeviceType
|
||||
from torch._C._distributed_c10d import _SymmetricMemory
|
||||
from torch._inductor.utils import fresh_inductor_cache, run_and_get_triton_code
|
||||
|
|
@ -81,42 +82,8 @@ class SymmetricMemoryTest(MultiProcessTestCase):
|
|||
rank=self.rank,
|
||||
store=store,
|
||||
)
|
||||
enable_symm_mem_for_group(dist.group.WORLD.group_name)
|
||||
torch.manual_seed(42 + self.rank)
|
||||
|
||||
def _get_test_alloc_args(self):
|
||||
shape = (64, 64)
|
||||
stride = (64, 1)
|
||||
dtype = torch.float32
|
||||
device = self.device
|
||||
group_name = "0"
|
||||
return (shape, stride, dtype, device, group_name)
|
||||
|
||||
def _verify_symmetric_memory(self, symm_mem):
|
||||
self.assertEqual(symm_mem.world_size, 2)
|
||||
|
||||
buf = symm_mem.get_buffer(0, (symm_mem.buffer_size // 4,), torch.float32)
|
||||
self.assertEqual(buf.storage_offset(), 0)
|
||||
self.assertEqual(buf.untyped_storage().size(), symm_mem.buffer_size)
|
||||
|
||||
if symm_mem.rank == 0:
|
||||
symm_mem.wait_signal(src_rank=1)
|
||||
self.assertTrue(buf.eq(42).all())
|
||||
else:
|
||||
buf.fill_(42)
|
||||
symm_mem.put_signal(dst_rank=0)
|
||||
|
||||
symm_mem.barrier()
|
||||
|
||||
if symm_mem.rank == 0:
|
||||
symm_mem.barrier()
|
||||
self.assertTrue(buf.eq(43).all())
|
||||
else:
|
||||
buf.fill_(43)
|
||||
symm_mem.barrier()
|
||||
|
||||
symm_mem.barrier()
|
||||
|
||||
@skipIfRocm
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_cuda_nvlink_connectivity_detection(self) -> None:
|
||||
|
|
@ -129,10 +96,51 @@ class SymmetricMemoryTest(MultiProcessTestCase):
|
|||
for row in connectivity.matrix:
|
||||
self.assertEqual(len(row), torch.cuda.device_count())
|
||||
|
||||
@skipIfRocm
def test_large_alloc(self) -> None:
    """Allocate a 2 GiB symmetric-memory buffer and verify its byte size."""
    nbytes = 2 * 1024**3
    buf = symm_mem.empty(nbytes, dtype=torch.uint8, device="cuda")
    self.assertEqual(buf.numel() * buf.element_size(), nbytes)
|
||||
|
||||
def _get_test_alloc_args(self):
    """Return the canonical (shape, stride, dtype, device, group_name) tuple
    used as positional args for ``_SymmetricMemory.empty_strided_p2p`` in tests."""
    return ((64, 64), (64, 1), torch.float32, self.device, "0")
|
||||
|
||||
def _verify_symmetric_memory(self, symm_mem_hdl):
    """Exercise buffer access, put/wait signaling, and barriers on a
    rendezvoused symmetric-memory handle (expects exactly 2 ranks)."""
    self.assertEqual(symm_mem_hdl.world_size, 2)

    # View the whole buffer as float32 (buffer_size is in bytes).
    nelems = symm_mem_hdl.buffer_size // 4
    buf = symm_mem_hdl.get_buffer(0, (nelems,), torch.float32)
    self.assertEqual(buf.storage_offset(), 0)
    self.assertEqual(buf.untyped_storage().size(), symm_mem_hdl.buffer_size)

    # Rank 1 writes 42 then signals; rank 0 waits for the signal and checks.
    if symm_mem_hdl.rank == 0:
        symm_mem_hdl.wait_signal(src_rank=1)
        self.assertTrue(buf.eq(42).all())
    else:
        buf.fill_(42)
        symm_mem_hdl.put_signal(dst_rank=0)

    symm_mem_hdl.barrier()

    # Same handshake again, but synchronized via barriers instead of signals.
    if symm_mem_hdl.rank == 0:
        symm_mem_hdl.barrier()
        self.assertTrue(buf.eq(43).all())
    else:
        buf.fill_(43)
        symm_mem_hdl.barrier()

    symm_mem_hdl.barrier()
|
||||
|
||||
@skipIfRocm
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_empty_strided_p2p(self) -> None:
|
||||
self._init_process()
|
||||
enable_symm_mem_for_group(dist.group.WORLD.group_name)
|
||||
|
||||
alloc_args = self._get_test_alloc_args()
|
||||
|
||||
|
|
@ -140,16 +148,17 @@ class SymmetricMemoryTest(MultiProcessTestCase):
|
|||
self.assertIsNone(_SymmetricMemory.rendezvous(t))
|
||||
|
||||
t = _SymmetricMemory.empty_strided_p2p(*alloc_args)
|
||||
symm_mem = _SymmetricMemory.rendezvous(t)
|
||||
symm_mem_hdl = _SymmetricMemory.rendezvous(t)
|
||||
|
||||
del t
|
||||
self._verify_symmetric_memory(symm_mem)
|
||||
self._verify_symmetric_memory(symm_mem_hdl)
|
||||
dist.destroy_process_group()
|
||||
|
||||
@skipIfRocm
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_empty_strided_p2p_persistent(self) -> None:
|
||||
self._init_process()
|
||||
enable_symm_mem_for_group(dist.group.WORLD.group_name)
|
||||
|
||||
alloc_args = self._get_test_alloc_args()
|
||||
|
||||
|
|
@ -168,8 +177,8 @@ class SymmetricMemoryTest(MultiProcessTestCase):
|
|||
t = _SymmetricMemory.empty_strided_p2p(*alloc_args, alloc_id=42)
|
||||
self.assertEqual(t.data_ptr(), data_ptr)
|
||||
|
||||
symm_mem = _SymmetricMemory.rendezvous(t)
|
||||
self._verify_symmetric_memory(symm_mem)
|
||||
symm_mem_hdl = _SymmetricMemory.rendezvous(t)
|
||||
self._verify_symmetric_memory(symm_mem_hdl)
|
||||
dist.destroy_process_group()
|
||||
|
||||
@skipIfRocm
|
||||
|
|
@ -177,42 +186,38 @@ class SymmetricMemoryTest(MultiProcessTestCase):
|
|||
def test_get_signal_pad(self) -> None:
|
||||
self._init_process()
|
||||
|
||||
t = _SymmetricMemory.empty_strided_p2p(*self._get_test_alloc_args())
|
||||
symm_mem = _SymmetricMemory.rendezvous(t)
|
||||
t = symm_mem.empty(1, device="cuda")
|
||||
symm_mem_hdl = symm_mem.rendezvous(t, group=dist.group.WORLD)
|
||||
peer_rank = (self.rank + 1) % self.world_size
|
||||
|
||||
signal_pad = symm_mem.get_signal_pad(self.rank)
|
||||
self.assertEqual(signal_pad.data_ptr(), symm_mem.signal_pad_ptrs[symm_mem.rank])
|
||||
signal_pad = symm_mem_hdl.get_signal_pad(self.rank)
|
||||
self.assertEqual(
|
||||
signal_pad.data_ptr(), symm_mem_hdl.signal_pad_ptrs[symm_mem_hdl.rank]
|
||||
)
|
||||
|
||||
signal_pad = symm_mem.get_signal_pad(peer_rank)
|
||||
signal_pad = symm_mem_hdl.get_signal_pad(peer_rank)
|
||||
self.assertEqual(signal_pad.dtype, torch.uint32)
|
||||
self.assertEqual(signal_pad.numel(), symm_mem.signal_pad_size // 4)
|
||||
self.assertEqual(signal_pad.numel(), symm_mem_hdl.signal_pad_size // 4)
|
||||
|
||||
# Only specify sizes
|
||||
signal_pad = symm_mem.get_signal_pad(peer_rank, (8, 8))
|
||||
signal_pad = symm_mem_hdl.get_signal_pad(peer_rank, (8, 8))
|
||||
self.assertEqual(signal_pad.dtype, torch.uint32)
|
||||
self.assertEqual(signal_pad.numel(), 64)
|
||||
|
||||
# Only specify dtype
|
||||
signal_pad = symm_mem.get_signal_pad(peer_rank, dtype=torch.uint64)
|
||||
signal_pad = symm_mem_hdl.get_signal_pad(peer_rank, dtype=torch.uint64)
|
||||
self.assertEqual(signal_pad.dtype, torch.uint64)
|
||||
self.assertEqual(signal_pad.numel(), symm_mem.signal_pad_size // 8)
|
||||
self.assertEqual(signal_pad.numel(), symm_mem_hdl.signal_pad_size // 8)
|
||||
|
||||
# Specify both sizes and dtype
|
||||
signal_pad = symm_mem.get_signal_pad(peer_rank, (8, 8), dtype=torch.uint64)
|
||||
signal_pad = symm_mem_hdl.get_signal_pad(peer_rank, (8, 8), dtype=torch.uint64)
|
||||
self.assertEqual(signal_pad.dtype, torch.uint64)
|
||||
self.assertEqual(signal_pad.numel(), 64)
|
||||
|
||||
# Sanity check that writes to buffer doesn't corrupt signal_pad
|
||||
t = _SymmetricMemory.empty_strided_p2p(
|
||||
(0,),
|
||||
(0,),
|
||||
torch.float32,
|
||||
self.device,
|
||||
dist.group.WORLD.group_name,
|
||||
)
|
||||
symm_mem = _SymmetricMemory.rendezvous(t)
|
||||
signal_pad = symm_mem.get_signal_pad(self.rank)
|
||||
t = symm_mem.empty(0, device="cuda")
|
||||
symm_mem_hdl = symm_mem.rendezvous(t)
|
||||
signal_pad = symm_mem_hdl.get_signal_pad(self.rank)
|
||||
signal_pad.fill_(42)
|
||||
t.fill_(0)
|
||||
self.assertTrue(signal_pad.eq(42).all())
|
||||
|
|
@ -224,14 +229,12 @@ class SymmetricMemoryTest(MultiProcessTestCase):
|
|||
def test_barrier_timeout(self) -> None:
|
||||
self._init_process()
|
||||
|
||||
alloc_args = self._get_test_alloc_args()
|
||||
|
||||
t = _SymmetricMemory.empty_strided_p2p(*alloc_args)
|
||||
symm_mem = _SymmetricMemory.rendezvous(t)
|
||||
t = symm_mem.empty(1, device="cuda")
|
||||
symm_mem_hdl = _SymmetricMemory.rendezvous(t, group=dist.group.WORLD)
|
||||
|
||||
if self.rank == 0:
|
||||
with self.assertRaises(RuntimeError):
|
||||
symm_mem.barrier(timeout_ms=1000)
|
||||
symm_mem_hdl.barrier(timeout_ms=1000)
|
||||
torch.cuda.synchronize()
|
||||
else:
|
||||
torch.cuda.synchronize()
|
||||
|
|
@ -247,17 +250,15 @@ class SymmetricMemoryTest(MultiProcessTestCase):
|
|||
def test_put_signal_timeout(self) -> None:
|
||||
self._init_process()
|
||||
|
||||
alloc_args = self._get_test_alloc_args()
|
||||
|
||||
t = _SymmetricMemory.empty_strided_p2p(*alloc_args)
|
||||
symm_mem = _SymmetricMemory.rendezvous(t)
|
||||
t = symm_mem.empty(1, device="cuda")
|
||||
symm_mem_hdl = _SymmetricMemory.rendezvous(t, group=dist.group.WORLD)
|
||||
|
||||
if self.rank == 0:
|
||||
with self.assertRaises(RuntimeError):
|
||||
# First, put a signal into rank 1's signal pad. Since rank 1
|
||||
# doesn't wait on this signal, the subsequent put will timeout.
|
||||
symm_mem.put_signal(dst_rank=1)
|
||||
symm_mem.put_signal(dst_rank=1, timeout_ms=1000)
|
||||
symm_mem_hdl.put_signal(dst_rank=1)
|
||||
symm_mem_hdl.put_signal(dst_rank=1, timeout_ms=1000)
|
||||
torch.cuda.synchronize()
|
||||
else:
|
||||
torch.cuda.synchronize()
|
||||
|
|
@ -273,14 +274,12 @@ class SymmetricMemoryTest(MultiProcessTestCase):
|
|||
def test_wait_signal_timeout(self) -> None:
|
||||
self._init_process()
|
||||
|
||||
alloc_args = self._get_test_alloc_args()
|
||||
|
||||
t = _SymmetricMemory.empty_strided_p2p(*alloc_args)
|
||||
symm_mem = _SymmetricMemory.rendezvous(t)
|
||||
t = symm_mem.empty(1, device="cuda")
|
||||
symm_mem_hdl = _SymmetricMemory.rendezvous(t, group=dist.group.WORLD)
|
||||
|
||||
if self.rank == 0:
|
||||
with self.assertRaises(RuntimeError):
|
||||
symm_mem.wait_signal(src_rank=1, timeout_ms=1000)
|
||||
symm_mem_hdl.wait_signal(src_rank=1, timeout_ms=1000)
|
||||
torch.cuda.synchronize()
|
||||
else:
|
||||
torch.cuda.synchronize()
|
||||
|
|
@ -685,7 +684,6 @@ class SubgroupTest(MultiProcessTestCase):
|
|||
rank=self.rank,
|
||||
store=store,
|
||||
)
|
||||
enable_symm_mem_for_group(dist.group.WORLD.group_name)
|
||||
torch.manual_seed(42 + self.rank)
|
||||
|
||||
@skipIfRocm
|
||||
|
|
@ -699,18 +697,10 @@ class SubgroupTest(MultiProcessTestCase):
|
|||
|
||||
world = dist.group.WORLD
|
||||
subgroup = subgroup_0 if world.rank() < world.size() // 2 else subgroup_1
|
||||
enable_symm_mem_for_group(subgroup.group_name)
|
||||
|
||||
t = _SymmetricMemory.empty_strided_p2p(
|
||||
size=(64,),
|
||||
stride=(1,),
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
symm_mem_world = _SymmetricMemory.rendezvous(t, group_name=world.group_name)
|
||||
symm_mem_subgroup = _SymmetricMemory.rendezvous(
|
||||
t, group_name=subgroup.group_name
|
||||
)
|
||||
t = symm_mem.empty(64, device="cuda")
|
||||
symm_mem_world = symm_mem.rendezvous(t, group=world)
|
||||
symm_mem_subgroup = symm_mem.rendezvous(t, group=subgroup)
|
||||
|
||||
self.assertEqual(symm_mem_world.world_size, world.size())
|
||||
self.assertEqual(symm_mem_world.rank, world.rank())
|
||||
|
|
|
|||
|
|
@ -110,6 +110,8 @@ def get_symm_mem_workspace(group_name: str, min_size: int) -> _SymmetricMemory:
|
|||
_SymmetricMemory: the symmetric memory workspace associated with the
|
||||
group.
|
||||
"""
|
||||
enable_symm_mem_for_group(group_name)
|
||||
|
||||
tensor = _group_name_to_workspace_tensor.get(group_name)
|
||||
size = tensor.numel() * tensor.element_size() if tensor is not None else 0
|
||||
if tensor is None or size < min_size:
|
||||
|
|
@ -1386,3 +1388,109 @@ def _low_contention_reduce_scatter(
|
|||
return _low_contention_reduce_scatter_with_workspace(
|
||||
tensor, reduce_op, workspace
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# User-facing APIs
|
||||
# =============================================================================
|
||||
|
||||
|
||||
from typing import Any, overload, Sequence, TYPE_CHECKING, Union
|
||||
|
||||
from torch.types import _device, _dtype, _int
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._C._distributed_c10d import ProcessGroup
|
||||
|
||||
|
||||
@overload
def empty(
    *size: _int, dtype: Optional[_dtype] = None, device: Optional[_device] = None
) -> torch.Tensor:
    ...


@overload
def empty(
    size: Sequence[_int],
    *,
    dtype: Optional[_dtype] = None,
    device: Optional[_device] = None,
) -> torch.Tensor:
    ...


def empty(  # type: ignore[misc]
    *size: Any,
    dtype: Optional[_dtype] = None,
    device: Optional[_device] = None,
) -> torch.Tensor:
    r"""
    empty(*size, *, dtype=None, device=None) -> Tensor

    Similar to :func:`torch.empty()`. The returned tensor can be used by
    :func:`torch._distributed._symmetric_memory.rendezvous()` to establish a
    symmetric memory tensor among participating processes.

    Args:
        size (int...): a sequence of integers defining the shape of the output tensor.
            Can be a variable number of arguments or a collection like a list or tuple.

    Keyword args:
        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
            Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
        device (:class:`torch.device`, optional): the desired device of returned tensor.
            Default: if ``None``, uses the current device for the default tensor type
            (see :func:`torch.set_default_device`). :attr:`device` will be the CPU
            for CPU tensor types and the current CUDA device for CUDA tensor types.
    """
    # Accept both empty(64, 64) and empty((64, 64)) / empty([64, 64]),
    # mirroring torch.empty()'s calling conventions.
    if len(size) == 1 and isinstance(size[0], Sequence):
        shape = tuple(size[0])
    else:
        shape = tuple(size)

    resolved_dtype = torch.get_default_dtype() if dtype is None else dtype
    resolved_device = torch.get_default_device() if device is None else device

    # Strides are always contiguous; the public API deliberately does not
    # expose a stride parameter.
    return _SymmetricMemory.empty_strided_p2p(
        size=shape,
        stride=torch._prims_common.make_contiguous_strides_for(shape),
        dtype=resolved_dtype,
        device=torch.device(resolved_device),
    )
|
||||
|
||||
|
||||
def rendezvous(
|
||||
tensor: torch.Tensor, group: Union[str, "ProcessGroup"]
|
||||
) -> _SymmetricMemory:
|
||||
r"""
|
||||
rendezvous(tensor, group) -> _SymmetricMemory
|
||||
|
||||
Establish a symmetric memory tensor among participating processes. This is
|
||||
a collective operation.
|
||||
|
||||
Args:
|
||||
tensor (:class:`torch.Tensor`): the local tensor used to establish the symmetric memory tensor.
|
||||
It must be allocated via :func:`torch._distributed._symmetric_memory.empty()`. The shape,
|
||||
dtype, and device type must be identical across all participating processes.
|
||||
group (Union[str, :class:`torch.distributed.ProcessGroup`]): The group identifying the
|
||||
participating processes. This can be either a group name or a process group object.
|
||||
"""
|
||||
from torch._C._distributed_c10d import ProcessGroup
|
||||
|
||||
if isinstance(group, str):
|
||||
group_name = group
|
||||
elif isinstance(group, ProcessGroup):
|
||||
group_name = group.group_name
|
||||
else:
|
||||
raise TypeError(f"rendezvous: unsupported group type: {type(group)}")
|
||||
|
||||
enable_symm_mem_for_group(group_name)
|
||||
return _SymmetricMemory.rendezvous(tensor, group_name)
|
||||
|
||||
|
||||
__all__ = ["empty", "rendezvous"]
|
||||
|
|
|
|||
Loading…
Reference in a new issue