[SymmetricMemory] introduce user-facing APIs empty() and rendezvous() (#139677)

Previously `SymmetricMemory` only had private pybind APIs:
```python
from torch.distributed._symmetric_memory import _SymmetricMemory
t = _SymmetricMemory.empty_strided_p2p(
    size=(64,),
    stride=(1,),
    dtype=torch.float32,
    device=device,
)
symm_mem_hdl = _SymmetricMemory.rendezvous(t, group_name=group.group_name)
```

This PR introduces user-facing APIs empty() and rendezvous():
```python
import torch.distributed._symmetric_memory as symm_mem
t = symm_mem.empty(64, device="cuda")
symm_mem_hdl = symm_mem.rendezvous(t, group_name=group.group_name)
```

Notable differences compared to the pybind APIs:
- `empty()` now resembles `torch.empty()`:
  - shape can either be an integer sequence or pack
  - no need to/can't specify stride anymore
  - device can either be `torch.device` or string
- `group_name` needs to be specified at rendezvous time as opposed to allocation time. See https://github.com/pytorch/pytorch/pull/139529 for the rationales. I feel the new semantic is superior, hence enforcing it in the public API.
  - Currently, the pybind API still support specifying `group_name` at rendezvous time.

This PR does not change the behavior of the pybind APIs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139677
Approved by: https://github.com/lw
ghstack dependencies: #139529
This commit is contained in:
Yifu Wang 2024-11-16 20:36:51 -08:00 committed by PyTorch MergeBot
parent 9f4af6b4e6
commit 5a7e147ef3
2 changed files with 183 additions and 85 deletions

View file

@ -5,6 +5,7 @@ from unittest import skipIf
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
from torch._C._autograd import DeviceType
from torch._C._distributed_c10d import _SymmetricMemory
from torch._inductor.utils import fresh_inductor_cache, run_and_get_triton_code
@ -81,42 +82,8 @@ class SymmetricMemoryTest(MultiProcessTestCase):
rank=self.rank,
store=store,
)
enable_symm_mem_for_group(dist.group.WORLD.group_name)
torch.manual_seed(42 + self.rank)
def _get_test_alloc_args(self):
shape = (64, 64)
stride = (64, 1)
dtype = torch.float32
device = self.device
group_name = "0"
return (shape, stride, dtype, device, group_name)
def _verify_symmetric_memory(self, symm_mem):
self.assertEqual(symm_mem.world_size, 2)
buf = symm_mem.get_buffer(0, (symm_mem.buffer_size // 4,), torch.float32)
self.assertEqual(buf.storage_offset(), 0)
self.assertEqual(buf.untyped_storage().size(), symm_mem.buffer_size)
if symm_mem.rank == 0:
symm_mem.wait_signal(src_rank=1)
self.assertTrue(buf.eq(42).all())
else:
buf.fill_(42)
symm_mem.put_signal(dst_rank=0)
symm_mem.barrier()
if symm_mem.rank == 0:
symm_mem.barrier()
self.assertTrue(buf.eq(43).all())
else:
buf.fill_(43)
symm_mem.barrier()
symm_mem.barrier()
@skipIfRocm
@skip_if_lt_x_gpu(2)
def test_cuda_nvlink_connectivity_detection(self) -> None:
@ -129,10 +96,51 @@ class SymmetricMemoryTest(MultiProcessTestCase):
for row in connectivity.matrix:
self.assertEqual(len(row), torch.cuda.device_count())
@skipIfRocm
def test_large_alloc(self) -> None:
t = symm_mem.empty(2 * 1024**3, dtype=torch.uint8, device="cuda")
self.assertEqual(t.numel() * t.element_size(), 2 * 1024**3)
def _get_test_alloc_args(self):
shape = (64, 64)
stride = (64, 1)
dtype = torch.float32
device = self.device
group_name = "0"
return (shape, stride, dtype, device, group_name)
def _verify_symmetric_memory(self, symm_mem_hdl):
self.assertEqual(symm_mem_hdl.world_size, 2)
buf = symm_mem_hdl.get_buffer(
0, (symm_mem_hdl.buffer_size // 4,), torch.float32
)
self.assertEqual(buf.storage_offset(), 0)
self.assertEqual(buf.untyped_storage().size(), symm_mem_hdl.buffer_size)
if symm_mem_hdl.rank == 0:
symm_mem_hdl.wait_signal(src_rank=1)
self.assertTrue(buf.eq(42).all())
else:
buf.fill_(42)
symm_mem_hdl.put_signal(dst_rank=0)
symm_mem_hdl.barrier()
if symm_mem_hdl.rank == 0:
symm_mem_hdl.barrier()
self.assertTrue(buf.eq(43).all())
else:
buf.fill_(43)
symm_mem_hdl.barrier()
symm_mem_hdl.barrier()
@skipIfRocm
@skip_if_lt_x_gpu(2)
def test_empty_strided_p2p(self) -> None:
self._init_process()
enable_symm_mem_for_group(dist.group.WORLD.group_name)
alloc_args = self._get_test_alloc_args()
@ -140,16 +148,17 @@ class SymmetricMemoryTest(MultiProcessTestCase):
self.assertIsNone(_SymmetricMemory.rendezvous(t))
t = _SymmetricMemory.empty_strided_p2p(*alloc_args)
symm_mem = _SymmetricMemory.rendezvous(t)
symm_mem_hdl = _SymmetricMemory.rendezvous(t)
del t
self._verify_symmetric_memory(symm_mem)
self._verify_symmetric_memory(symm_mem_hdl)
dist.destroy_process_group()
@skipIfRocm
@skip_if_lt_x_gpu(2)
def test_empty_strided_p2p_persistent(self) -> None:
self._init_process()
enable_symm_mem_for_group(dist.group.WORLD.group_name)
alloc_args = self._get_test_alloc_args()
@ -168,8 +177,8 @@ class SymmetricMemoryTest(MultiProcessTestCase):
t = _SymmetricMemory.empty_strided_p2p(*alloc_args, alloc_id=42)
self.assertEqual(t.data_ptr(), data_ptr)
symm_mem = _SymmetricMemory.rendezvous(t)
self._verify_symmetric_memory(symm_mem)
symm_mem_hdl = _SymmetricMemory.rendezvous(t)
self._verify_symmetric_memory(symm_mem_hdl)
dist.destroy_process_group()
@skipIfRocm
@ -177,42 +186,38 @@ class SymmetricMemoryTest(MultiProcessTestCase):
def test_get_signal_pad(self) -> None:
self._init_process()
t = _SymmetricMemory.empty_strided_p2p(*self._get_test_alloc_args())
symm_mem = _SymmetricMemory.rendezvous(t)
t = symm_mem.empty(1, device="cuda")
symm_mem_hdl = symm_mem.rendezvous(t, group=dist.group.WORLD)
peer_rank = (self.rank + 1) % self.world_size
signal_pad = symm_mem.get_signal_pad(self.rank)
self.assertEqual(signal_pad.data_ptr(), symm_mem.signal_pad_ptrs[symm_mem.rank])
signal_pad = symm_mem_hdl.get_signal_pad(self.rank)
self.assertEqual(
signal_pad.data_ptr(), symm_mem_hdl.signal_pad_ptrs[symm_mem_hdl.rank]
)
signal_pad = symm_mem.get_signal_pad(peer_rank)
signal_pad = symm_mem_hdl.get_signal_pad(peer_rank)
self.assertEqual(signal_pad.dtype, torch.uint32)
self.assertEqual(signal_pad.numel(), symm_mem.signal_pad_size // 4)
self.assertEqual(signal_pad.numel(), symm_mem_hdl.signal_pad_size // 4)
# Only specify sizes
signal_pad = symm_mem.get_signal_pad(peer_rank, (8, 8))
signal_pad = symm_mem_hdl.get_signal_pad(peer_rank, (8, 8))
self.assertEqual(signal_pad.dtype, torch.uint32)
self.assertEqual(signal_pad.numel(), 64)
# Only specify dtype
signal_pad = symm_mem.get_signal_pad(peer_rank, dtype=torch.uint64)
signal_pad = symm_mem_hdl.get_signal_pad(peer_rank, dtype=torch.uint64)
self.assertEqual(signal_pad.dtype, torch.uint64)
self.assertEqual(signal_pad.numel(), symm_mem.signal_pad_size // 8)
self.assertEqual(signal_pad.numel(), symm_mem_hdl.signal_pad_size // 8)
# Specify both sizes and dtype
signal_pad = symm_mem.get_signal_pad(peer_rank, (8, 8), dtype=torch.uint64)
signal_pad = symm_mem_hdl.get_signal_pad(peer_rank, (8, 8), dtype=torch.uint64)
self.assertEqual(signal_pad.dtype, torch.uint64)
self.assertEqual(signal_pad.numel(), 64)
# Sanity check that writes to buffer doesn't corrupt signal_pad
t = _SymmetricMemory.empty_strided_p2p(
(0,),
(0,),
torch.float32,
self.device,
dist.group.WORLD.group_name,
)
symm_mem = _SymmetricMemory.rendezvous(t)
signal_pad = symm_mem.get_signal_pad(self.rank)
t = symm_mem.empty(0, device="cuda")
symm_mem_hdl = symm_mem.rendezvous(t)
signal_pad = symm_mem_hdl.get_signal_pad(self.rank)
signal_pad.fill_(42)
t.fill_(0)
self.assertTrue(signal_pad.eq(42).all())
@ -224,14 +229,12 @@ class SymmetricMemoryTest(MultiProcessTestCase):
def test_barrier_timeout(self) -> None:
self._init_process()
alloc_args = self._get_test_alloc_args()
t = _SymmetricMemory.empty_strided_p2p(*alloc_args)
symm_mem = _SymmetricMemory.rendezvous(t)
t = symm_mem.empty(1, device="cuda")
symm_mem_hdl = _SymmetricMemory.rendezvous(t, group=dist.group.WORLD)
if self.rank == 0:
with self.assertRaises(RuntimeError):
symm_mem.barrier(timeout_ms=1000)
symm_mem_hdl.barrier(timeout_ms=1000)
torch.cuda.synchronize()
else:
torch.cuda.synchronize()
@ -247,17 +250,15 @@ class SymmetricMemoryTest(MultiProcessTestCase):
def test_put_signal_timeout(self) -> None:
self._init_process()
alloc_args = self._get_test_alloc_args()
t = _SymmetricMemory.empty_strided_p2p(*alloc_args)
symm_mem = _SymmetricMemory.rendezvous(t)
t = symm_mem.empty(1, device="cuda")
symm_mem_hdl = _SymmetricMemory.rendezvous(t, group=dist.group.WORLD)
if self.rank == 0:
with self.assertRaises(RuntimeError):
# First, put a signal into rank 1's signal pad. Since rank 1
# doesn't wait on this signal, the subsequent put will timeout.
symm_mem.put_signal(dst_rank=1)
symm_mem.put_signal(dst_rank=1, timeout_ms=1000)
symm_mem_hdl.put_signal(dst_rank=1)
symm_mem_hdl.put_signal(dst_rank=1, timeout_ms=1000)
torch.cuda.synchronize()
else:
torch.cuda.synchronize()
@ -273,14 +274,12 @@ class SymmetricMemoryTest(MultiProcessTestCase):
def test_wait_signal_timeout(self) -> None:
self._init_process()
alloc_args = self._get_test_alloc_args()
t = _SymmetricMemory.empty_strided_p2p(*alloc_args)
symm_mem = _SymmetricMemory.rendezvous(t)
t = symm_mem.empty(1, device="cuda")
symm_mem_hdl = _SymmetricMemory.rendezvous(t, group=dist.group.WORLD)
if self.rank == 0:
with self.assertRaises(RuntimeError):
symm_mem.wait_signal(src_rank=1, timeout_ms=1000)
symm_mem_hdl.wait_signal(src_rank=1, timeout_ms=1000)
torch.cuda.synchronize()
else:
torch.cuda.synchronize()
@ -685,7 +684,6 @@ class SubgroupTest(MultiProcessTestCase):
rank=self.rank,
store=store,
)
enable_symm_mem_for_group(dist.group.WORLD.group_name)
torch.manual_seed(42 + self.rank)
@skipIfRocm
@ -699,18 +697,10 @@ class SubgroupTest(MultiProcessTestCase):
world = dist.group.WORLD
subgroup = subgroup_0 if world.rank() < world.size() // 2 else subgroup_1
enable_symm_mem_for_group(subgroup.group_name)
t = _SymmetricMemory.empty_strided_p2p(
size=(64,),
stride=(1,),
dtype=torch.float32,
device=self.device,
)
symm_mem_world = _SymmetricMemory.rendezvous(t, group_name=world.group_name)
symm_mem_subgroup = _SymmetricMemory.rendezvous(
t, group_name=subgroup.group_name
)
t = symm_mem.empty(64, device="cuda")
symm_mem_world = symm_mem.rendezvous(t, group=world)
symm_mem_subgroup = symm_mem.rendezvous(t, group=subgroup)
self.assertEqual(symm_mem_world.world_size, world.size())
self.assertEqual(symm_mem_world.rank, world.rank())

View file

@ -110,6 +110,8 @@ def get_symm_mem_workspace(group_name: str, min_size: int) -> _SymmetricMemory:
_SymmetricMemory: the symmetric memory workspace associated with the
group.
"""
enable_symm_mem_for_group(group_name)
tensor = _group_name_to_workspace_tensor.get(group_name)
size = tensor.numel() * tensor.element_size() if tensor is not None else 0
if tensor is None or size < min_size:
@ -1386,3 +1388,109 @@ def _low_contention_reduce_scatter(
return _low_contention_reduce_scatter_with_workspace(
tensor, reduce_op, workspace
)
# =============================================================================
# User-facing APIs
# =============================================================================
from typing import Any, overload, Sequence, TYPE_CHECKING, Union
from torch.types import _device, _dtype, _int
if TYPE_CHECKING:
from torch._C._distributed_c10d import ProcessGroup
@overload
def empty(
*size: _int, dtype: Optional[_dtype] = None, device: Optional[_device] = None
) -> torch.Tensor:
...
@overload
def empty(
size: Sequence[_int],
*,
dtype: Optional[_dtype] = None,
device: Optional[_device] = None,
) -> torch.Tensor:
...
def empty( # type: ignore[misc]
*size: Any,
dtype: Optional[_dtype] = None,
device: Optional[_device] = None,
) -> torch.Tensor:
r"""
empty(*size, *, dtype=None, device=None) -> Tensor
Similar to :func:`torch.empty()`. The returned tensor can be used by
:func:`torch._distributed._symmetric_memory.rendezvous()` to establish a
symmetric memory tensor among participating processes.
Args:
size (int...): a sequence of integers defining the shape of the output tensor.
Can be a variable number of arguments or a collection like a list or tuple.
Keyword args:
dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
device (:class:`torch.device`, optional): the desired device of returned tensor.
Default: if ``None``, uses the current device for the default tensor type
(see :func:`torch.set_default_device`). :attr:`device` will be the CPU
for CPU tensor types and the current CUDA device for CUDA tensor types.
"""
if len(size) == 1 and isinstance(size[0], Sequence):
size = tuple(size[0])
else:
size = tuple(size)
if dtype is None:
dtype = torch.get_default_dtype()
if device is None:
device = torch.get_default_device()
return _SymmetricMemory.empty_strided_p2p(
size=size,
stride=torch._prims_common.make_contiguous_strides_for(size),
dtype=dtype,
device=torch.device(device),
)
def rendezvous(
tensor: torch.Tensor, group: Union[str, "ProcessGroup"]
) -> _SymmetricMemory:
r"""
rendezvous(tensor, group) -> _SymmetricMemory
Establish a symmetric memory tensor among participating processes. This is
a collective operation.
Args:
tensor (:class:`torch.Tensor`): the local tensor used to establish the symmetric memory tensor.
It must be allocated via :func:`torch._distributed._symmetric_memory.empty()`. The shape,
dtype, and device type must be identical across all participating processes.
group (Union[str, :class:`torch.distributed.ProcessGroup`]): The group identifying the
participating processes. This can be either a group name or a process group object.
"""
from torch._C._distributed_c10d import ProcessGroup
if isinstance(group, str):
group_name = group
elif isinstance(group, ProcessGroup):
group_name = group.group_name
else:
raise TypeError(f"rendezvous: unsupported group type: {type(group)}")
enable_symm_mem_for_group(group_name)
return _SymmetricMemory.rendezvous(tensor, group_name)
__all__ = ["empty", "rendezvous"]