From ab5c8857eff71e9a2e2f8b074034925249ccc5b4 Mon Sep 17 00:00:00 2001
From: Yifu Wang
Date: Sat, 16 Nov 2024 20:36:51 -0800
Subject: [PATCH] [SymmetricMemory] support specifying group_name at rendezvous time (#139529)

Before this PR, users needed to call `empty_strided_p2p()` with a `group_name`:

```python
tensor = _SymmetricMemory.empty_strided_p2p((1024,), (1,), device=device, group_name="0")
symm_mem = _SymmetricMemory.rendezvous(tensor)
```

Users can now omit `group_name` at allocation time and specify it later at rendezvous time:

```python
tensor = _SymmetricMemory.empty_strided_p2p((1024,), (1,), device=device)
symm_mem = _SymmetricMemory.rendezvous(tensor, group_name="0")
```

Rationale for this change:
- It allows the same allocation to establish symmetric memory under different groups.
- Specifying `group_name` at rendezvous time rather than at allocation time is a more natural UX.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139529
Approved by: https://github.com/lw
---
 test/distributed/test_symmetric_memory.py     |  87 +++++++--
 torch/_C/_distributed_c10d.pyi                |   7 +-
 .../distributed/c10d/CUDASymmetricMemory.cu   | 169 ++++++++++--------
 .../distributed/c10d/CUDASymmetricMemory.hpp  |  56 +++---
 .../csrc/distributed/c10d/SymmetricMemory.cpp |  19 +-
 .../csrc/distributed/c10d/SymmetricMemory.hpp |  17 +-
 torch/csrc/distributed/c10d/init.cpp          |   9 +-
 .../csrc/distributed/c10d/intra_node_comm.cpp |   4 +-
 8 files changed, 225 insertions(+), 143 deletions(-)

diff --git a/test/distributed/test_symmetric_memory.py b/test/distributed/test_symmetric_memory.py
index c32f0d36610..bbe229ba098 100644
--- a/test/distributed/test_symmetric_memory.py
+++ b/test/distributed/test_symmetric_memory.py
@@ -97,7 +97,7 @@ class SymmetricMemoryTest(MultiProcessTestCase):
 
         buf = symm_mem.get_buffer(0, (symm_mem.buffer_size // 4,), torch.float32)
         self.assertEqual(buf.storage_offset(), 0)
-        self.assertEqual(buf.storage().size(), symm_mem.buffer_size // 4)
+        self.assertEqual(buf.untyped_storage().size(), symm_mem.buffer_size)
 
         if symm_mem.rank == 0:
             symm_mem.wait_signal(src_rank=1)
@@ -168,16 +168,8 @@ class SymmetricMemoryTest(MultiProcessTestCase):
         t = _SymmetricMemory.empty_strided_p2p(*alloc_args, alloc_id=42)
         self.assertEqual(t.data_ptr(), data_ptr)
 
-        # Verify that get_symmetric_memory would fail if called before
-        # rendezvous.
-        with self.assertRaises(RuntimeError):
-            _SymmetricMemory.get_symmetric_memory(t)
-
-        symm_mem_0 = _SymmetricMemory.rendezvous(t)
-        symm_mem_1 = _SymmetricMemory.get_symmetric_memory(t)
-        self.assertEqual(id(symm_mem_0), id(symm_mem_1))
-
-        self._verify_symmetric_memory(symm_mem_0)
+        symm_mem = _SymmetricMemory.rendezvous(t)
+        self._verify_symmetric_memory(symm_mem)
         dist.destroy_process_group()
 
     @skipIfRocm
@@ -669,6 +661,79 @@ class SymmetricMemoryTest(MultiProcessTestCase):
         dist.destroy_process_group()
 
 
+@instantiate_parametrized_tests
+@requires_cuda_p2p_access()
+class SubgroupTest(MultiProcessTestCase):
+    def setUp(self) -> None:
+        super().setUp()
+        self._spawn_processes()
+
+    @property
+    def world_size(self) -> int:
+        return 4
+
+    @property
+    def device(self) -> torch.device:
+        return torch.device(f"cuda:{self.rank}")
+
+    def _init_process(self):
+        torch.cuda.set_device(self.device)
+        store = dist.FileStore(self.file_name, self.world_size)
+        dist.init_process_group(
+            backend="nccl",
+            world_size=self.world_size,
+            rank=self.rank,
+            store=store,
+        )
+        enable_symm_mem_for_group(dist.group.WORLD.group_name)
+        torch.manual_seed(42 + self.rank)
+
+    @skipIfRocm
+    @skip_if_lt_x_gpu(4)
+    def test_subgroup(self) -> None:
+        self._init_process()
+
+        ranks = list(range(self.world_size))
+        subgroup_0 = dist.new_group(ranks[: len(ranks) // 2])
+        subgroup_1 = dist.new_group(ranks[len(ranks) // 2 :])
+
+        world = dist.group.WORLD
+        subgroup = subgroup_0 if world.rank() < world.size() // 2 else subgroup_1
+        enable_symm_mem_for_group(subgroup.group_name)
+
+        t = _SymmetricMemory.empty_strided_p2p(
+            size=(64,),
+            stride=(1,),
+            dtype=torch.float32,
+            device=self.device,
+        )
+        symm_mem_world = _SymmetricMemory.rendezvous(t, group_name=world.group_name)
+        symm_mem_subgroup = _SymmetricMemory.rendezvous(
+            t, group_name=subgroup.group_name
+        )
+
+        self.assertEqual(symm_mem_world.world_size, world.size())
+        self.assertEqual(symm_mem_world.rank, world.rank())
+        self.assertEqual(symm_mem_subgroup.world_size, world.size() // 2)
+        self.assertEqual(symm_mem_subgroup.rank, world.rank() % subgroup.size())
+
+        t.fill_(world.rank())
+        symm_mem_world.barrier()
+
+        # Observe a peer buffer via the world group
+        peer_rank = (world.rank() + 1) % world.size()
+        buf = symm_mem_world.get_buffer(peer_rank, (64,), torch.float32)
+        self.assertTrue(buf.eq(peer_rank).all())
+
+        # Observe a peer buffer via the subgroup
+        peer_rank = (subgroup.rank() + 1) % subgroup.size()
+        buf = symm_mem_subgroup.get_buffer(peer_rank, (64,), torch.float32)
+        if world.rank() < world.size() // 2:
+            self.assertTrue(buf.eq(peer_rank).all())
+        else:
+            self.assertTrue(buf.eq(peer_rank + world.size() // 2).all())
+
+
 @instantiate_parametrized_tests
 @requires_cuda_p2p_access()
 class SymmMemAllReduceTest(MultiProcessTestCase):
diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi
index e8102a766c7..bdc58946fd8 100644
--- a/torch/_C/_distributed_c10d.pyi
+++ b/torch/_C/_distributed_c10d.pyi
@@ -662,14 +662,17 @@ class _SymmetricMemory:
         stride: torch.types._size,
         dtype: torch.dtype,
         device: torch.device,
-        group_name: str,
+        group_name: str | None = None,
+        alloc_id: int | None = None,
     ) -> torch.Tensor: ...
    @property
    def rank(self) -> int: ...
    @property
    def world_size(self) -> int: ...
    @staticmethod
-    def rendezvous(tensor: torch.Tensor) -> _SymmetricMemory: ...
+    def rendezvous(
+        tensor: torch.Tensor, group_name: str | None = None
+    ) -> _SymmetricMemory: ...
    def get_buffer(
        self,
        rank: int,
diff --git a/torch/csrc/distributed/c10d/CUDASymmetricMemory.cu b/torch/csrc/distributed/c10d/CUDASymmetricMemory.cu
index 5af394cff52..d5e65eac975 100644
--- a/torch/csrc/distributed/c10d/CUDASymmetricMemory.cu
+++ b/torch/csrc/distributed/c10d/CUDASymmetricMemory.cu
@@ -95,7 +95,6 @@ class IpcChannel {
     cmsg->cmsg_type = SCM_RIGHTS;
 
     if (fd != -1) {
-      // memcpy(CMSG_DATA(cmsg), &fd, sizeof(fd));
       std::copy(
           reinterpret_cast<const char*>(&fd),
           reinterpret_cast<const char*>(&fd) + sizeof(fd),
@@ -273,9 +272,29 @@ void map_block(
 namespace c10d {
 namespace symmetric_memory {
 
+AllocationRef::AllocationRef(void* ptr, HandleType handle, size_t block_size, int device_idx)
+    : ptr(ptr), handle(handle), block_size(block_size), device_idx(device_idx) {}
+
+AllocationRef::~AllocationRef() {
+#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
+  // Leak the cuda allocations during static deinitialization
+  if (is_finalizing()) {
+    return;
+  }
+  auto driver_api = c10::cuda::DriverAPI::get();
+  c10::cuda::CUDAGuard guard(device_idx);
+  C10_CUDA_CHECK(cudaDeviceSynchronize());
+  C10_CUDA_DRIVER_CHECK(
+      driver_api->cuMemUnmap_(reinterpret_cast<CUdeviceptr>(ptr), block_size));
+  C10_CUDA_DRIVER_CHECK(driver_api->cuMemRelease_(handle));
+#else
+  TORCH_CHECK(
+      false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED");
+#endif
+}
+
 CUDASymmetricMemory::CUDASymmetricMemory(
-    std::vector<HandleType> handles,
-    size_t block_size,
+    std::vector<c10::intrusive_ptr<AllocationRef>> alloc_refs,
     std::vector<void*> buffers,
     std::vector<void*> signal_pads,
     HandleType mc_handle,
@@ -284,8 +303,7 @@ CUDASymmetricMemory::CUDASymmetricMemory(
     int local_device_idx,
     int rank,
     int world_size)
-    : handles_(std::move(handles)),
-      block_size_(block_size),
+    : alloc_refs_(std::move(alloc_refs)),
       buffers_(std::move(buffers)),
       signal_pads_(std::move(signal_pads)),
       mc_handle_(mc_handle),
@@ -307,29 +325,6 @@ CUDASymmetricMemory::CUDASymmetricMemory(
       signal_pads_dev_, signal_pads_.data(), arr_size, cudaMemcpyHostToDevice));
 }
 
-CUDASymmetricMemory::~CUDASymmetricMemory() {
-#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
-  // Leak the cuda allocations during static deinitialization
-  if (is_finalizing()) {
-    return;
-  }
-  c10::cuda::CUDAGuard guard(local_device_idx_);
-  C10_CUDA_CHECK(cudaDeviceSynchronize());
-
-  auto driver_api = c10::cuda::DriverAPI::get();
-  for (int r = 0; r < world_size_; ++r) {
-    C10_CUDA_DRIVER_CHECK(driver_api->cuMemUnmap_(
-        reinterpret_cast<CUdeviceptr>(buffers_[r]), block_size_));
-    C10_CUDA_DRIVER_CHECK(driver_api->cuMemRelease_(handles_[r]));
-  }
-  c10::cuda::CUDACachingAllocator::raw_delete(buffers_dev_);
-  c10::cuda::CUDACachingAllocator::raw_delete(signal_pads_dev_);
-#else
-  TORCH_CHECK(
-      false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED");
-#endif
-}
-
 std::vector<void*> CUDASymmetricMemory::get_buffer_ptrs() {
   return buffers_;
 }
@@ -592,12 +587,27 @@ int CUDASymmetricMemory::get_world_size() {
   return world_size_;
 }
 
+Block::Block(
+    c10::intrusive_ptr<AllocationRef> alloc_ref,
+    int device_idx,
+    size_t block_size,
+    size_t buffer_size,
+    size_t signal_pad_offset,
+    const std::optional<std::string>& group_name)
+    : alloc_ref(std::move(alloc_ref)),
+      device_idx(device_idx),
+      block_size(block_size),
+      buffer_size(buffer_size),
+      signal_pad_offset(signal_pad_offset),
+      default_group_name(std::move(group_name)) {}
+
 void* CUDASymmetricMemoryAllocator::alloc(
     size_t size,
     int device_idx,
-    const std::string& group_name) {
+    const std::optional<std::string>& group_name) {
 #if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
-  auto driver_api = c10::cuda::DriverAPI::get();
+  c10::cuda::CUDAGuard guard(device_idx);
+  device_idx = static_cast<int>(guard.current_device().index());
 
   CUmemAllocationProp prop = {};
   prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
@@ -610,6 +620,7 @@ void* CUDASymmetricMemoryAllocator::alloc(
   size_t block_size = signal_pad_offset + signal_pad_size;
 
   size_t granularity;
+  auto driver_api = c10::cuda::DriverAPI::get();
   C10_CUDA_DRIVER_CHECK(driver_api->cuMemGetAllocationGranularity_(
       &granularity, &prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED));
   block_size = at::round_up(block_size, granularity);
@@ -621,11 +632,16 @@ void* CUDASymmetricMemoryAllocator::alloc(
 
   void* ptr = nullptr;
   map_block(&ptr, handle, block_size, device_idx);
-  c10::cuda::CUDAGuard guard(device_idx);
   AT_CUDA_CHECK(cudaMemset(ptr, 0, block_size));
 
+  auto alloc_ref = c10::make_intrusive<AllocationRef>(ptr, handle, block_size, device_idx);
   auto block = c10::make_intrusive<Block>(
-      handle, device_idx, block_size, size, signal_pad_offset, group_name);
+      std::move(alloc_ref),
+      device_idx,
+      block_size,
+      size,
+      signal_pad_offset,
+      group_name);
   {
     std::unique_lock<std::mutex> lock(mutex_);
     ptr_to_block_.emplace(ptr, std::move(block));
@@ -638,28 +654,8 @@
 }
 
 void CUDASymmetricMemoryAllocator::free(void* ptr) {
-#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
-  auto block = find_block(ptr);
-  // Leak the cuda allocations during static deinitialization
-  if (block == nullptr || is_finalizing()) {
-    return;
-  }
-  // Initializing CUDASymmetricMemory with an allocation transfers its
-  // ownership to the CUDASymmetricMemory object.
-  if (block->symm_mem == nullptr) {
-    auto driver_api = c10::cuda::DriverAPI::get();
-    C10_CUDA_DRIVER_CHECK(driver_api->cuMemUnmap_(
-        reinterpret_cast<CUdeviceptr>(ptr), block->block_size));
-    C10_CUDA_DRIVER_CHECK(driver_api->cuMemRelease_(block->handle));
-  }
-  {
-    std::unique_lock<std::mutex> lock(mutex_);
-    ptr_to_block_.erase(ptr);
-  }
-#else
-  TORCH_CHECK(
-      false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED");
-#endif
+  std::unique_lock<std::mutex> lock(mutex_);
+  ptr_to_block_.erase(ptr);
 }
 
 size_t CUDASymmetricMemoryAllocator::get_alloc_size(void* ptr) {
@@ -785,7 +781,7 @@ static void init_multicast_for_block(
     C10_CUDA_DRIVER_CHECK(
        driver_api->cuMulticastAddDevice_(mc_handle, block->device_idx));
     C10_CUDA_DRIVER_CHECK(driver_api->cuMulticastBindMem_(
-        mc_handle, 0, block->handle, 0, block->block_size, 0));
+        mc_handle, 0, block->alloc_ref->handle, 0, block->block_size, 0));
 
  map_block(&mc_addr, mc_handle, block->block_size, block->device_idx);
  store_barrier(store, rank, world_size);
@@ -793,19 +789,36 @@ c10::intrusive_ptr<SymmetricMemory> CUDASymmetricMemoryAllocator::rendezvous(
-    void* ptr) {
+    void* ptr,
+    const std::optional<std::string>& group_name) {
 #if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
   auto block = find_block(ptr);
   if (block == nullptr) {
     return nullptr;
   }
 
-  if (block->symm_mem != nullptr) {
-    return block->symm_mem;
+  // The group_name passed to rendezvous() takes precedence over
+  // the default group_name specified during allocation.
+  std::string group_name_;
+  if (group_name.has_value()) {
+    group_name_ = *group_name;
+  } else {
+    if (!block->default_group_name.has_value()) {
+      TORCH_CHECK(
+          false,
+          "CUDASymmetricMemory::rendezvous: `group_name` is neither "
+          "specified during allocation nor passed to rendezvous().");
+    }
+    group_name_ = *block->default_group_name;
+  }
+
+  auto it = block->symm_mems.find(group_name_);
+  if (it != block->symm_mems.end()) {
+    return it->second;
   }
 
   IpcChannel ipc_channel;
-  auto group_info = get_group_info(block->group_name);
+  auto group_info = get_group_info(group_name_);
   auto store = group_info.store;
   int rank = group_info.rank;
   int world_size = group_info.world_size;
@@ -813,7 +826,10 @@ c10::intrusive_ptr<SymmetricMemory> CUDASymmetricMemoryAllocator::rendezvous(
   auto driver_api = c10::cuda::DriverAPI::get();
   int block_fd;
   C10_CUDA_DRIVER_CHECK(driver_api->cuMemExportToShareableHandle_(
-      &block_fd, block->handle, CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, 0));
+      &block_fd,
+      block->alloc_ref->handle,
+      CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR,
+      0));
 
   auto local_req = RendezvousRequest{
       .device_idx = block->device_idx,
@@ -837,7 +853,7 @@ c10::intrusive_ptr<SymmetricMemory> CUDASymmetricMemoryAllocator::rendezvous(
 
   for (int r = 0; r < world_size; ++r) {
     if (r == rank) {
-      handles[r] = block->handle;
+      handles[r] = block->alloc_ref->handle;
       buffers[r] = ptr;
       signal_pads[r] = (void*)((uintptr_t)ptr + block->signal_pad_offset);
       continue;
@@ -861,13 +877,18 @@ c10::intrusive_ptr<SymmetricMemory> CUDASymmetricMemoryAllocator::rendezvous(
         mc_handle, mc_addr, block, ipc_channel, pids, store, rank, world_size);
   }
 
-  // Initializing CUDASymmetricMemory with an allocation transfers its
-  // ownership to the CUDASymmetricMemory object. So that outstanding
-  // references to the CUDASymmetricMemory object can keep the allocation
-  // alive.
-  block->symm_mem = c10::make_intrusive<CUDASymmetricMemory>(
-      std::move(handles),
-      block->block_size,
+  std::vector<c10::intrusive_ptr<AllocationRef>> alloc_refs;
+  for (int r = 0; r < world_size; ++r) {
+    if (r == rank) {
+      alloc_refs.emplace_back(block->alloc_ref);
+      continue;
+    }
+    alloc_refs.push_back(c10::make_intrusive<AllocationRef>(
+        buffers[r], handles[r], block->block_size, block->device_idx));
+  }
+
+  auto symm_mem = c10::make_intrusive<CUDASymmetricMemory>(
+      std::move(alloc_refs),
       std::move(buffers),
       std::move(signal_pads),
       mc_handle,
@@ -876,22 +897,14 @@ c10::intrusive_ptr<SymmetricMemory> CUDASymmetricMemoryAllocator::rendezvous(
       block->device_idx,
       group_info.rank,
       group_info.world_size);
-  return block->symm_mem;
+  block->symm_mems[group_name_] = symm_mem;
+  return symm_mem;
 #else
   TORCH_CHECK(
       false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED");
 #endif
 }
 
-bool CUDASymmetricMemoryAllocator::is_rendezvous_completed(void* ptr) {
-  auto block = find_block(ptr);
-  TORCH_CHECK(
-      block != nullptr,
-      "CUDASymmetricMemoryAllocator::is_rendezvous_completed: input must be allocated ",
-      "via CUDASymmetricMemoryAllocator::alloc");
-  return block->symm_mem != nullptr;
-}
-
 bool CUDASymmetricMemoryAllocator::has_multicast_support(int device_idx) {
   return device_has_multicast_support(device_idx);
 }
diff --git a/torch/csrc/distributed/c10d/CUDASymmetricMemory.hpp b/torch/csrc/distributed/c10d/CUDASymmetricMemory.hpp
index 0e08591f946..3b32dbb4fd8 100644
--- a/torch/csrc/distributed/c10d/CUDASymmetricMemory.hpp
+++ b/torch/csrc/distributed/c10d/CUDASymmetricMemory.hpp
@@ -12,11 +12,27 @@ using HandleType = CUmemGenericAllocationHandle;
 using HandleType = void*;
 #endif
 
+// Resource wrapper that owns a (vaddr, allocation handle) pair. Upon
+// destruction, it unmaps the vaddr and releases the allocation handle.
+struct AllocationRef : public c10::intrusive_ptr_target {
+  void* ptr;
+  HandleType handle;
+  size_t block_size;
+  int device_idx;
+
+  AllocationRef(
+      void* ptr,
+      HandleType handle,
+      size_t block_size,
+      int device_idx);
+
+  ~AllocationRef();
+};
+
 class CUDASymmetricMemory : public SymmetricMemory {
  public:
  CUDASymmetricMemory(
-      std::vector<HandleType> handles,
-      size_t block_size,
+      std::vector<c10::intrusive_ptr<AllocationRef>> alloc_refs,
      std::vector<void*> buffers,
      std::vector<void*> signal_pads,
      HandleType mc_handle,
@@ -26,7 +42,7 @@ class CUDASymmetricMemory : public SymmetricMemory {
      int rank,
      int world_size);
 
-  ~CUDASymmetricMemory() override;
+  ~CUDASymmetricMemory() override{};
 
  std::vector<void*> get_buffer_ptrs() override;
  std::vector<void*> get_signal_pad_ptrs() override;
@@ -58,8 +74,7 @@ class CUDASymmetricMemory : public SymmetricMemory {
  int get_world_size() override;
 
  private:
-  std::vector<HandleType> handles_;
-  size_t block_size_;
+  std::vector<c10::intrusive_ptr<AllocationRef>> alloc_refs_;
  std::vector<void*> buffers_;
  std::vector<void*> signal_pads_;
  HandleType mc_handle_;
@@ -70,43 +85,40 @@ class CUDASymmetricMemory : public SymmetricMemory {
  int world_size_;
  void** buffers_dev_;
  void** signal_pads_dev_;
-  std::optional<std::function<void(void)>> finalizer_;
 };
 
+// Metadata associated with each allocation performed by
+// `CUDASymmetricMemoryAllocator`.
 struct Block : public c10::intrusive_ptr_target {
-  HandleType handle;
+  c10::intrusive_ptr<AllocationRef> alloc_ref;
  int device_idx;
  size_t block_size;
  size_t buffer_size;
  size_t signal_pad_offset;
-  std::string group_name;
-  c10::intrusive_ptr<SymmetricMemory> symm_mem = nullptr;
+  std::optional<std::string> default_group_name;
+  std::map<std::string, c10::intrusive_ptr<SymmetricMemory>> symm_mems;
 
  Block(
-      HandleType handle,
+      c10::intrusive_ptr<AllocationRef> alloc_ref,
      int device_idx,
      size_t block_size,
      size_t buffer_size,
      size_t signal_pad_offset,
-      std::string group_name)
-      : handle(handle),
-        device_idx(device_idx),
-        block_size(block_size),
-        buffer_size(buffer_size),
-        signal_pad_offset(signal_pad_offset),
-        group_name(std::move(group_name)),
-        symm_mem(nullptr) {}
+      const std::optional<std::string>& group_name);
 };
 
 class CUDASymmetricMemoryAllocator : public SymmetricMemoryAllocator {
  public:
-  void* alloc(size_t size, int device_idx, const std::string& group_name)
-      override;
+  void* alloc(
+      size_t size,
+      int device_idx,
+      const std::optional<std::string>& group_name) override;
 
  void free(void* ptr) override;
  size_t get_alloc_size(void* ptr) override;
-  c10::intrusive_ptr<SymmetricMemory> rendezvous(void* ptr) override;
-  bool is_rendezvous_completed(void* ptr) override;
+  c10::intrusive_ptr<SymmetricMemory> rendezvous(
+      void* ptr,
+      const std::optional<std::string>& group_name) override;
  bool has_multicast_support(int device_idx) override;
 
 private:
diff --git a/torch/csrc/distributed/c10d/SymmetricMemory.cpp b/torch/csrc/distributed/c10d/SymmetricMemory.cpp
index 7911a9d875b..b37186b2d37 100644
--- a/torch/csrc/distributed/c10d/SymmetricMemory.cpp
+++ b/torch/csrc/distributed/c10d/SymmetricMemory.cpp
@@ -56,7 +56,7 @@ static at::Tensor empty_strided_p2p_persistent(
     c10::IntArrayRef stride,
     c10::ScalarType dtype,
     c10::Device device,
-    const std::string& group_name,
+    const std::optional<std::string>& group_name,
     uint64_t alloc_id) {
   // Make the allocation fails if a previous allocation with the same alloc_id
   // is still active.
@@ -153,7 +153,7 @@ at::Tensor empty_strided_p2p(
     c10::IntArrayRef stride,
     c10::ScalarType dtype,
     c10::Device device,
-    const std::string& group_name,
+    const std::optional<std::string>& group_name,
     std::optional<uint64_t> alloc_id) {
   if (alloc_id.has_value()) {
     return empty_strided_p2p_persistent(
@@ -181,19 +181,10 @@ at::Tensor empty_strided_p2p(
 }
 
 TORCH_API c10::intrusive_ptr<SymmetricMemory> rendezvous(
-    const at::Tensor& tensor) {
+    const at::Tensor& tensor,
+    const std::optional<std::string>& group_name) {
   auto allocator = get_allocator(tensor.device().type());
-  return allocator->rendezvous(tensor.storage().data_ptr().get());
-}
-
-c10::intrusive_ptr<SymmetricMemory> get_symmetric_memory(
-    const at::Tensor& tensor) {
-  auto allocator = get_allocator(tensor.device().type());
-  TORCH_CHECK(
-      allocator->is_rendezvous_completed(tensor.data_ptr()),
-      "SymmetricMemory: must invoke rendezvous on a tensor ",
-      "before calling get_symmetric_memory on it");
-  return allocator->rendezvous(tensor.data_ptr());
+  return allocator->rendezvous(tensor.storage().data_ptr().get(), group_name);
 }
 
 TORCH_API bool has_multicast_support(
diff --git a/torch/csrc/distributed/c10d/SymmetricMemory.hpp b/torch/csrc/distributed/c10d/SymmetricMemory.hpp
index 72d6a132ab4..982b8d3449e 100644
--- a/torch/csrc/distributed/c10d/SymmetricMemory.hpp
+++ b/torch/csrc/distributed/c10d/SymmetricMemory.hpp
@@ -80,12 +80,13 @@ class SymmetricMemoryAllocator : public c10::intrusive_ptr_target {
  virtual void* alloc(
      size_t size,
      int device_idx,
-      const std::string& group_name) = 0;
+      const std::optional<std::string>& group_name) = 0;
 
  virtual void free(void* ptr) = 0;
  virtual size_t get_alloc_size(void* ptr) = 0;
-  virtual c10::intrusive_ptr<SymmetricMemory> rendezvous(void* ptr) = 0;
-  virtual bool is_rendezvous_completed(void* ptr) = 0;
+  virtual c10::intrusive_ptr<SymmetricMemory> rendezvous(
+      void* ptr,
+      const std::optional<std::string>& group_name) = 0;
  virtual bool has_multicast_support(int device_idx) = 0;
 };
 
@@ -138,7 +139,7 @@ TORCH_API at::Tensor empty_strided_p2p(
     c10::IntArrayRef stride,
     c10::ScalarType dtype,
     c10::Device device,
-    const std::string& group_name,
+    const std::optional<std::string>& group_name,
     std::optional<uint64_t> alloc_id);
 
 // Establishes symmetric memory access on tensors allocated via
@@ -152,12 +153,8 @@ TORCH_API at::Tensor empty_strided_p2p(
 // The function has a collective semantic and must be invoked simultaneously
 // from all rendezvous participants.
 TORCH_API c10::intrusive_ptr<SymmetricMemory> rendezvous(
-    const at::Tensor& tensor);
-
-// Returns the SymmetricMemory object associated with the tensor. It can only
-// be invoked after rendezvous() but does not need to be invoked collectively.
-TORCH_API c10::intrusive_ptr<SymmetricMemory> get_symmetric_memory(
-    const at::Tensor& tensor);
+    const at::Tensor& tensor,
+    const std::optional<std::string>& group_name = std::nullopt);
 
 TORCH_API bool has_multicast_support(
     c10::DeviceType device_type,
diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index 6f42039e3ab..bbcfd10e58e 100644
--- a/torch/csrc/distributed/c10d/init.cpp
+++ b/torch/csrc/distributed/c10d/init.cpp
@@ -1057,12 +1057,13 @@ This class does not support ``__members__`` property.)");
          py::arg("stride"),
          py::arg("dtype"),
          py::arg("device"),
-          py::arg("group_name"),
+          py::arg("group_name") = py::none(),
          py::arg("alloc_id") = py::none())
-      .def_static("rendezvous", &::c10d::symmetric_memory::rendezvous)
      .def_static(
-          "get_symmetric_memory",
-          &::c10d::symmetric_memory::get_symmetric_memory)
+          "rendezvous",
+          &::c10d::symmetric_memory::rendezvous,
+          py::arg("tensor"),
+          py::arg("group_name") = py::none())
      .def_static(
          "has_multicast_support",
          &::c10d::symmetric_memory::has_multicast_support)
diff --git a/torch/csrc/distributed/c10d/intra_node_comm.cpp b/torch/csrc/distributed/c10d/intra_node_comm.cpp
index c0c53d220d8..e4c4c491f7e 100644
--- a/torch/csrc/distributed/c10d/intra_node_comm.cpp
+++ b/torch/csrc/distributed/c10d/intra_node_comm.cpp
@@ -176,8 +176,8 @@ bool IntraNodeComm::rendezvous() {
  set_group_info(
      groupName, static_cast<int>(rank_), static_cast<int>(worldSize_), store_);
  auto allocator = get_allocator(c10::DeviceType::CUDA);
-  symmetricMemoryPtr_ = allocator->alloc(bufferSize_, deviceIdx_, groupName);
-  symmetricMemory_ = allocator->rendezvous(symmetricMemoryPtr_);
+  symmetricMemoryPtr_ = allocator->alloc(bufferSize_, deviceIdx_, std::nullopt);
+  symmetricMemory_ = allocator->rendezvous(symmetricMemoryPtr_, groupName);
  isInitialized_ = true;
  return true;
 #endif
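
A minimal usage sketch of the rendezvous-time `group_name` API (not part of the patch): it follows the PR description and the `SubgroupTest` added above. The import paths and the `enable_symm_mem_for_group` helper are the ones used by `test_symmetric_memory.py`; the group split, shapes, and device strings are illustrative assumptions and assume a 4-rank NCCL job with one CUDA device per rank.

```python
# Sketch only: allocate once without a group_name, then bind the same
# allocation to one or more groups at rendezvous time.
import torch
import torch.distributed as dist
from torch._C._distributed_c10d import _SymmetricMemory
from torch.distributed._symmetric_memory import enable_symm_mem_for_group

# Assumes dist.init_process_group("nccl", ...) has already run on 4 ranks
# and torch.cuda.set_device(rank) has been called.
world = dist.group.WORLD
ranks = list(range(world.size()))
subgroup = dist.new_group(ranks[: world.size() // 2])  # first half of the ranks

enable_symm_mem_for_group(world.group_name)
enable_symm_mem_for_group(subgroup.group_name)

# No group_name at allocation time.
t = _SymmetricMemory.empty_strided_p2p(
    size=(1024,),
    stride=(1,),
    dtype=torch.float32,
    device=torch.device(f"cuda:{world.rank()}"),
)

# Rendezvous is collective per group; the same tensor can rendezvous under
# the world group and, on the participating ranks, under a subgroup as well.
symm_mem_world = _SymmetricMemory.rendezvous(t, group_name=world.group_name)
if world.rank() < world.size() // 2:
    symm_mem_sub = _SymmetricMemory.rendezvous(t, group_name=subgroup.group_name)
```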