mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): This PR introduces a prototype for `SymmetricMemory` (including a CUDA implementation) - a remote-memory access-based communication primitive. It allows for user-defined communication patterns/kernels and is designed to be torch.compile-friendly. It addresses the major limitations of `IntraNodeComm` and `ProcessGroupCudaP2p` and serves as a replacement for them. ### SymmetricMemory `SymmetricMemory` represents symmetric allocations across a group of devices. The allocations represented by a `SymmetricMemory` object are accessible by all devices in the group. The class can be used for **op-level custom communication patterns** (via the get_buffer APIs and the synchronization primitives), as well as **custom communication kernels** (via the buffer and signal_pad device pointers). ### Python API Example ```python from torch._C._distributed_c10d import _SymmetricMemory # Set a store for rendezvousing symmetric allocations on a group of devices # identified by group_name. The concept of groups is logical; users can # utilize predefined groups (e.g., a group of devices identified by a # ProcessGroup) or create custom ones. Note that SymmetricMemoryAllocator # backends might employ a more efficient communication channel for the actual # rendezvous process and only use the store for bootstrapping purposes. _SymmetricMemory.set_group_info(group_name, rank, world_size, store) # Identical to empty_strided, but allows symmetric memory access to be # established for the allocated tensor via _SymmetricMemory.rendezvous(). # This function itself is not a collective operation. t = _SymmetricMemory.empty_strided_p2p((64, 64), (64, 1), torch.float32, group_name) # Users can write Python custom ops that leverage the symmetric memory access. # Below are examples of things users can do (assuming the group's world_size is 2). 
# Establishes symmetric memory access on tensors allocated via # _SymmetricMemory.empty_strided_p2p(). rendezvous() is a one-time process, # and the mapping between a local memory region and the associated SymmetricMemory # object is unique. Subsequent calls to rendezvous() with the same tensor will receive # the cached SymmetricMemory object. # # The function has collective semantics and must be invoked simultaneously # from all rendezvous participants. symm_mem = _SymmetricMemory.rendezvous(t) # This represents the allocation on rank 0 and is accessible from all devices. buf = symm_mem.get_buffer(0, (64, 64), torch.float32) if symm_mem.rank == 0: symm_mem.wait_signal(src_rank=1) assert buf.eq(42).all() else: # The remote buffer can be used as a regular tensor buf.fill_(42) symm_mem.put_signal(dst_rank=0) symm_mem.barrier() if symm_mem.rank == 0: symm_mem.barrier() assert buf.eq(43).all() else: new_val = torch.empty_like(buf) new_val.fill_(43) # Contiguous copies to/from a remote buffer utilize copy engines # which bypass SMs (i.e. no need to load the data into registers) buf.copy_(new_val) symm_mem.barrier() ``` ### Custom CUDA Comm Kernels Given a tensor, users can access the associated `SymmetricMemory`, which provides pointers to the remote buffers/signal_pads needed for custom communication kernels. ```cpp TORCH_API c10::intrusive_ptr<SymmetricMemory> get_symmetric_memory( const at::Tensor& tensor); class TORCH_API SymmetricMemory : public c10::intrusive_ptr_target { public: ... virtual std::vector<void*> get_buffer_ptrs() = 0; virtual std::vector<void*> get_signal_pad_ptrs() = 0; virtual void** get_buffer_ptrs_dev() = 0; virtual void** get_signal_pad_ptrs_dev() = 0; virtual size_t get_buffer_size() = 0; virtual size_t get_signal_pad_size() = 0; virtual int get_rank() = 0; virtual int get_world_size() = 0; ... 
}; ``` ### Limitations of IntraNodeComm and ProcessGroupCudaP2p Both `IntraNodeComm` and `ProcessGroupCudaP2p` (which uses `IntraNodeComm`) manage a single fixed-size workspace. This approach: - Leads to awkward UX in which the required workspace needs to be specified upfront. - Cannot avoid extra copies for some algorithms in eager mode (e.g., custom/multimem all-reduce, reduce-scatter, all-gather). - Prevents torch.compile from eliminating all copies. In addition, they only offer out-of-the-box communication kernels and don't expose required pointers for user-defined, custom CUDA comm kernels. * __->__ #128582 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128582 Approved by: https://github.com/wanchaol
156 lines
4.5 KiB
Python
156 lines
4.5 KiB
Python
# Owner(s): ["module: c10d"]
|
|
|
|
import torch
|
|
|
|
import torch.distributed as dist
|
|
from torch._C._distributed_c10d import _SymmetricMemory
|
|
from torch.distributed.distributed_c10d import _get_process_group_store
|
|
|
|
from torch.testing._internal.common_distributed import (
|
|
MultiProcessTestCase,
|
|
skip_if_lt_x_gpu,
|
|
)
|
|
from torch.testing._internal.common_utils import (
|
|
instantiate_parametrized_tests,
|
|
run_tests,
|
|
skip_but_pass_in_sandcastle_if,
|
|
skipIfRocm,
|
|
)
|
|
|
|
|
|
def requires_cuda_p2p_access():
    """Build a skip decorator for tests that need CUDA peer-to-peer access.

    P2P access is treated as available only when CUDA is usable, at least
    two devices are visible, and every device pair ``(i, j)`` with ``i < j``
    passes ``torch.cuda.can_device_access_peer(i, j)``.

    Returns:
        The result of ``skip_but_pass_in_sandcastle_if`` called with the
        negated availability flag and a fixed reason string.
    """
    p2p_ok = torch.cuda.is_available() and torch.cuda.device_count() >= 2
    gpu_count = torch.cuda.device_count()

    for src in range(gpu_count - 1):
        # all() stops probing this row at the first unreachable peer.
        row_reachable = all(
            torch.cuda.can_device_access_peer(src, dst)
            for dst in range(src + 1, gpu_count)
        )
        if not row_reachable:
            p2p_ok = False
        # Once the flag is down there is no point probing further rows.
        if not p2p_ok:
            break

    return skip_but_pass_in_sandcastle_if(
        not p2p_ok,
        "cuda p2p access is not available",
    )
|
|
|
|
|
|
@instantiate_parametrized_tests
@requires_cuda_p2p_access()
class SymmetricMemoryTest(MultiProcessTestCase):
    """Two-process tests for ``_SymmetricMemory`` allocation and rendezvous."""

    def setUp(self) -> None:
        """Run the base-class setup, then spawn the worker processes."""
        super().setUp()
        self._spawn_processes()

    @property
    def world_size(self) -> int:
        """Number of participating processes; fixed at two."""
        return 2

    @property
    def device(self) -> torch.device:
        """CUDA device whose index equals this process's rank."""
        return torch.device(f"cuda:{self.rank}")

    def _init_process(self):
        """Prepare this rank for a test.

        Selects the rank's CUDA device, initializes an NCCL process group
        backed by a file store, and registers symmetric-memory group "0"
        using the store of the default (WORLD) process group.
        """
        torch.cuda.set_device(self.device)

        file_store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend="nccl",
            world_size=self.world_size,
            rank=self.rank,
            store=file_store,
        )

        group_store = _get_process_group_store(dist.GroupMember.WORLD)
        _SymmetricMemory.set_group_info(
            "0", self.rank, self.world_size, group_store
        )

    def _verify_symmetric_memory(self, symm_mem):
        """Exercise signal- and barrier-based hand-offs over buffer 0.

        Both ranks take a (64, 64) float32 view of buffer 0. In the first
        round the non-zero rank fills it with 42 and signals rank 0, which
        waits for the signal and checks the contents. In the second round
        the non-zero rank fills it with 43 and rank 0 checks the contents
        after a barrier.
        """
        self.assertEqual(symm_mem.world_size, 2)

        shared = symm_mem.get_buffer(0, (64, 64), torch.float32)
        on_rank0 = symm_mem.rank == 0

        # Round 1: writer fills, then signals; rank 0 waits, then checks.
        if not on_rank0:
            shared.fill_(42)
            symm_mem.put_signal(dst_rank=0)
        else:
            symm_mem.wait_signal(src_rank=1)
            self.assertTrue(shared.eq(42).all())

        symm_mem.barrier()

        # Round 2: writer fills before the barrier; rank 0 checks after it.
        if not on_rank0:
            shared.fill_(43)
            symm_mem.barrier()
        else:
            symm_mem.barrier()
            self.assertTrue(shared.eq(43).all())

        symm_mem.barrier()

    @skipIfRocm
    @skip_if_lt_x_gpu(2)
    def test_empty_strided_p2p(self) -> None:
        """Rendezvous rejects regular tensors and works for p2p allocations."""
        self._init_process()

        shape, stride = (64, 64), (64, 1)
        dtype = torch.float32
        dev = self.device

        # A tensor that did not come from empty_strided_p2p() must be
        # rejected by rendezvous().
        tensor = torch.empty(shape, dtype=dtype, device=dev)
        with self.assertRaises(RuntimeError):
            _SymmetricMemory.rendezvous(tensor)

        tensor = _SymmetricMemory.empty_strided_p2p(shape, stride, dtype, dev, "0")
        symm_mem = _SymmetricMemory.rendezvous(tensor)

        # Drop the tensor first so verification runs through symm_mem alone.
        del tensor
        self._verify_symmetric_memory(symm_mem)

    @skipIfRocm
    @skip_if_lt_x_gpu(2)
    def test_empty_strided_p2p_persistent(self) -> None:
        """Persistent (alloc_id-keyed) p2p allocations are exclusive and reusable."""
        self._init_process()

        dev = self.device
        persistent_alloc_id = 42  # Passing an alloc_id makes the allocation persistent

        def allocate():
            # Every allocation in this test uses the exact same arguments.
            return _SymmetricMemory.empty_strided_p2p(
                (64, 64), (64, 1), torch.float32, dev, "0", persistent_alloc_id
            )

        tensor = allocate()
        first_ptr = tensor.data_ptr()

        # While an allocation with this alloc_id is still alive, allocating
        # again with the same alloc_id must fail.
        with self.assertRaises(RuntimeError):
            allocate()

        # With no active allocation for this alloc_id, allocating again must
        # succeed and hand back the same data pointer.
        del tensor
        tensor = allocate()
        self.assertEqual(tensor.data_ptr(), first_ptr)

        # get_symmetric_memory() must fail when called before rendezvous().
        with self.assertRaises(RuntimeError):
            _SymmetricMemory.get_symmetric_memory(tensor)

        from_rendezvous = _SymmetricMemory.rendezvous(tensor)
        from_lookup = _SymmetricMemory.get_symmetric_memory(tensor)
        # Both calls must yield the very same Python object.
        self.assertIs(from_rendezvous, from_lookup)

        self._verify_symmetric_memory(from_rendezvous)
|
|
|
|
|
# Run the test suite only when this file is executed directly as a script.
if __name__ == "__main__":
    run_tests()
|