[SymmetricMemory] support specifying group_name at rendezvous time (#139529)

Before this PR, users needed to call `empty_strided_p2p()` with a `group_name`:

```python
tensor = _SymmetricMemory.empty_strided_p2p((1024,), (1,), device=device, group_name="0")
symm_mem = _SymmetricMemory.rendezvous(tensor)
```

Users can now omit `group_name` at allocation time and specify it later, at rendezvous time (a `group_name` passed to `rendezvous()` takes precedence over one specified at allocation):

```python
tensor = _SymmetricMemory.empty_strided_p2p((1024,), (1,), device=device)
symm_mem = _SymmetricMemory.rendezvous(tensor, group_name="0")
```
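
If `group_name` is omitted in both places, there is no group to rendezvous under and the call fails (see the `TORCH_CHECK` in the C++ changes below). A minimal sketch of that failure mode:

```python
t = _SymmetricMemory.empty_strided_p2p((1024,), (1,), device=device)
# Raises RuntimeError: `group_name` is neither specified during
# allocation nor passed to rendezvous().
symm_mem = _SymmetricMemory.rendezvous(t)
```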

Rationale for this change:
- It allows the same allocation to establish symmetric memory under different groups (see the sketch below)
- Specifying `group_name` at rendezvous time rather than at allocation time is a more natural UX
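
A minimal sketch of the first point, modeled on the `SubgroupTest` added in this PR (it assumes an already-initialized NCCL process group, one CUDA device per rank, and a `subgroup` created via `dist.new_group()` that contains the current rank):

```python
import torch
import torch.distributed as dist
from torch._C._distributed_c10d import _SymmetricMemory
from torch.distributed._symmetric_memory import enable_symm_mem_for_group

# Assumes dist.init_process_group() has already run; `subgroup` is a
# hypothetical process group (e.g. from dist.new_group()) containing this rank.
device = torch.device(f"cuda:{dist.get_rank()}")
enable_symm_mem_for_group(dist.group.WORLD.group_name)
enable_symm_mem_for_group(subgroup.group_name)

# One allocation with no group_name ...
t = _SymmetricMemory.empty_strided_p2p(
    size=(64,), stride=(1,), dtype=torch.float32, device=device
)
# ... rendezvoused under two different groups.
symm_mem_world = _SymmetricMemory.rendezvous(t, group_name=dist.group.WORLD.group_name)
symm_mem_sub = _SymmetricMemory.rendezvous(t, group_name=subgroup.group_name)
```

Each `group_name` yields its own `_SymmetricMemory` handle; as the C++ changes below show, the handles are cached per group on the underlying allocation (`block->symm_mems`).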

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139529
Approved by: https://github.com/lw
Author: Yifu Wang
Date: 2024-11-16 20:36:51 -08:00
Committed by: PyTorch MergeBot
Parent: 602ae9cbcf
Commit: ab5c8857ef

8 changed files with 225 additions and 143 deletions

test/distributed/test_symmetric_memory.py

@@ -97,7 +97,7 @@ class SymmetricMemoryTest(MultiProcessTestCase):
buf = symm_mem.get_buffer(0, (symm_mem.buffer_size // 4,), torch.float32)
self.assertEqual(buf.storage_offset(), 0)
self.assertEqual(buf.storage().size(), symm_mem.buffer_size // 4)
self.assertEqual(buf.untyped_storage().size(), symm_mem.buffer_size)
if symm_mem.rank == 0:
symm_mem.wait_signal(src_rank=1)
@@ -168,16 +168,8 @@ class SymmetricMemoryTest(MultiProcessTestCase):
t = _SymmetricMemory.empty_strided_p2p(*alloc_args, alloc_id=42)
self.assertEqual(t.data_ptr(), data_ptr)
# Verify that get_symmetric_memory would fail if called before
# rendezvous.
with self.assertRaises(RuntimeError):
_SymmetricMemory.get_symmetric_memory(t)
symm_mem_0 = _SymmetricMemory.rendezvous(t)
symm_mem_1 = _SymmetricMemory.get_symmetric_memory(t)
self.assertEqual(id(symm_mem_0), id(symm_mem_1))
self._verify_symmetric_memory(symm_mem_0)
symm_mem = _SymmetricMemory.rendezvous(t)
self._verify_symmetric_memory(symm_mem)
dist.destroy_process_group()
@skipIfRocm
@@ -669,6 +661,79 @@ class SymmetricMemoryTest(MultiProcessTestCase):
dist.destroy_process_group()
@instantiate_parametrized_tests
@requires_cuda_p2p_access()
class SubgroupTest(MultiProcessTestCase):
def setUp(self) -> None:
super().setUp()
self._spawn_processes()
@property
def world_size(self) -> int:
return 4
@property
def device(self) -> torch.device:
return torch.device(f"cuda:{self.rank}")
def _init_process(self):
torch.cuda.set_device(self.device)
store = dist.FileStore(self.file_name, self.world_size)
dist.init_process_group(
backend="nccl",
world_size=self.world_size,
rank=self.rank,
store=store,
)
enable_symm_mem_for_group(dist.group.WORLD.group_name)
torch.manual_seed(42 + self.rank)
@skipIfRocm
@skip_if_lt_x_gpu(4)
def test_subgroup(self) -> None:
self._init_process()
ranks = list(range(self.world_size))
subgroup_0 = dist.new_group(ranks[: len(ranks) // 2])
subgroup_1 = dist.new_group(ranks[len(ranks) // 2 :])
world = dist.group.WORLD
subgroup = subgroup_0 if world.rank() < world.size() // 2 else subgroup_1
enable_symm_mem_for_group(subgroup.group_name)
t = _SymmetricMemory.empty_strided_p2p(
size=(64,),
stride=(1,),
dtype=torch.float32,
device=self.device,
)
symm_mem_world = _SymmetricMemory.rendezvous(t, group_name=world.group_name)
symm_mem_subgroup = _SymmetricMemory.rendezvous(
t, group_name=subgroup.group_name
)
self.assertEqual(symm_mem_world.world_size, world.size())
self.assertEqual(symm_mem_world.rank, world.rank())
self.assertEqual(symm_mem_subgroup.world_size, world.size() // 2)
self.assertEqual(symm_mem_subgroup.rank, world.rank() % subgroup.size())
t.fill_(world.rank())
symm_mem_world.barrier()
# Observe a peer buffer via the world group
peer_rank = (world.rank() + 1) % world.size()
buf = symm_mem_world.get_buffer(peer_rank, (64,), torch.float32)
self.assertTrue(buf.eq(peer_rank).all())
# Observe a peer buffer via the subgroup
peer_rank = (subgroup.rank() + 1) % subgroup.size()
buf = symm_mem_subgroup.get_buffer(peer_rank, (64,), torch.float32)
if world.rank() < world.size() // 2:
self.assertTrue(buf.eq(peer_rank).all())
else:
self.assertTrue(buf.eq(peer_rank + world.size() // 2).all())
@instantiate_parametrized_tests
@requires_cuda_p2p_access()
class SymmMemAllReduceTest(MultiProcessTestCase):

torch/_C/_distributed_c10d.pyi

@@ -662,14 +662,17 @@ class _SymmetricMemory:
stride: torch.types._size,
dtype: torch.dtype,
device: torch.device,
group_name: str,
group_name: str | None = None,
alloc_id: int | None = None,
) -> torch.Tensor: ...
@property
def rank(self) -> int: ...
@property
def world_size(self) -> int: ...
@staticmethod
def rendezvous(tensor: torch.Tensor) -> _SymmetricMemory: ...
def rendezvous(
tensor: torch.Tensor, group_name: str | None = None
) -> _SymmetricMemory: ...
def get_buffer(
self,
rank: int,

torch/csrc/distributed/c10d/CUDASymmetricMemory.cu

@@ -95,7 +95,7 @@ class IpcChannel {
cmsg->cmsg_type = SCM_RIGHTS;
if (fd != -1) {
// memcpy(CMSG_DATA(cmsg), &fd, sizeof(fd));
std::copy(
reinterpret_cast<const char*>(&fd),
reinterpret_cast<const char*>(&fd) + sizeof(fd),
@@ -273,9 +272,29 @@ void map_block(
namespace c10d {
namespace symmetric_memory {
AllocationRef::AllocationRef(void* ptr, HandleType handle, size_t block_size, int device_idx)
: ptr(ptr), handle(handle), block_size(block_size), device_idx(device_idx) {}
AllocationRef::~AllocationRef() {
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
// Leak the cuda allocations during static deinitialization
if (is_finalizing()) {
return;
}
auto driver_api = c10::cuda::DriverAPI::get();
c10::cuda::CUDAGuard guard(device_idx);
C10_CUDA_CHECK(cudaDeviceSynchronize());
C10_CUDA_DRIVER_CHECK(
driver_api->cuMemUnmap_(reinterpret_cast<CUdeviceptr>(ptr), block_size));
C10_CUDA_DRIVER_CHECK(driver_api->cuMemRelease_(handle));
#else
TORCH_CHECK(
false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED");
#endif
}
CUDASymmetricMemory::CUDASymmetricMemory(
std::vector<HandleType> handles,
size_t block_size,
std::vector<c10::intrusive_ptr<AllocationRef>> alloc_refs,
std::vector<void*> buffers,
std::vector<void*> signal_pads,
HandleType mc_handle,
@@ -284,8 +303,7 @@ CUDASymmetricMemory::CUDASymmetricMemory(
int local_device_idx,
int rank,
int world_size)
: handles_(std::move(handles)),
block_size_(block_size),
: alloc_refs_(std::move(alloc_refs)),
buffers_(std::move(buffers)),
signal_pads_(std::move(signal_pads)),
mc_handle_(mc_handle),
@@ -307,29 +325,6 @@ CUDASymmetricMemory::CUDASymmetricMemory(
signal_pads_dev_, signal_pads_.data(), arr_size, cudaMemcpyHostToDevice));
}
CUDASymmetricMemory::~CUDASymmetricMemory() {
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
// Leak the cuda allocations during static deinitialization
if (is_finalizing()) {
return;
}
c10::cuda::CUDAGuard guard(local_device_idx_);
C10_CUDA_CHECK(cudaDeviceSynchronize());
auto driver_api = c10::cuda::DriverAPI::get();
for (int r = 0; r < world_size_; ++r) {
C10_CUDA_DRIVER_CHECK(driver_api->cuMemUnmap_(
reinterpret_cast<CUdeviceptr>(buffers_[r]), block_size_));
C10_CUDA_DRIVER_CHECK(driver_api->cuMemRelease_(handles_[r]));
}
c10::cuda::CUDACachingAllocator::raw_delete(buffers_dev_);
c10::cuda::CUDACachingAllocator::raw_delete(signal_pads_dev_);
#else
TORCH_CHECK(
false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED");
#endif
}
std::vector<void*> CUDASymmetricMemory::get_buffer_ptrs() {
return buffers_;
}
@@ -592,12 +587,27 @@ int CUDASymmetricMemory::get_world_size() {
return world_size_;
}
Block::Block(
c10::intrusive_ptr<AllocationRef> alloc_ref,
int device_idx,
size_t block_size,
size_t buffer_size,
size_t signal_pad_offset,
const std::optional<std::string>& group_name)
: alloc_ref(std::move(alloc_ref)),
device_idx(device_idx),
block_size(block_size),
buffer_size(buffer_size),
signal_pad_offset(signal_pad_offset),
default_group_name(std::move(group_name)) {}
void* CUDASymmetricMemoryAllocator::alloc(
size_t size,
int device_idx,
const std::string& group_name) {
const std::optional<std::string>& group_name) {
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
auto driver_api = c10::cuda::DriverAPI::get();
c10::cuda::CUDAGuard guard(device_idx);
device_idx = static_cast<int>(guard.current_device().index());
CUmemAllocationProp prop = {};
prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
@@ -610,6 +620,7 @@ void* CUDASymmetricMemoryAllocator::alloc(
size_t block_size = signal_pad_offset + signal_pad_size;
size_t granularity;
auto driver_api = c10::cuda::DriverAPI::get();
C10_CUDA_DRIVER_CHECK(driver_api->cuMemGetAllocationGranularity_(
&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED));
block_size = at::round_up(block_size, granularity);
@@ -621,11 +632,16 @@ void* CUDASymmetricMemoryAllocator::alloc(
void* ptr = nullptr;
map_block(&ptr, handle, block_size, device_idx);
c10::cuda::CUDAGuard guard(device_idx);
AT_CUDA_CHECK(cudaMemset(ptr, 0, block_size));
auto alloc_ref = c10::make_intrusive<AllocationRef>(ptr, handle, block_size, device_idx);
auto block = c10::make_intrusive<Block>(
handle, device_idx, block_size, size, signal_pad_offset, group_name);
std::move(alloc_ref),
device_idx,
block_size,
size,
signal_pad_offset,
group_name);
{
std::unique_lock lock(mutex_);
ptr_to_block_.emplace(ptr, std::move(block));
@@ -638,28 +654,8 @@ void* CUDASymmetricMemoryAllocator::alloc(
}
void CUDASymmetricMemoryAllocator::free(void* ptr) {
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
auto block = find_block(ptr);
// Leak the cuda allocations during static deinitialization
if (block == nullptr || is_finalizing()) {
return;
}
// Initializing CUDASymmetricMemory with an allocation transfers its
// ownership to the CUDASymmetricMemory object.
if (block->symm_mem == nullptr) {
auto driver_api = c10::cuda::DriverAPI::get();
C10_CUDA_DRIVER_CHECK(driver_api->cuMemUnmap_(
reinterpret_cast<CUdeviceptr>(ptr), block->block_size));
C10_CUDA_DRIVER_CHECK(driver_api->cuMemRelease_(block->handle));
}
{
std::unique_lock lock(mutex_);
ptr_to_block_.erase(ptr);
}
#else
TORCH_CHECK(
false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED");
#endif
std::unique_lock lock(mutex_);
ptr_to_block_.erase(ptr);
}
size_t CUDASymmetricMemoryAllocator::get_alloc_size(void* ptr) {
@@ -785,7 +781,7 @@ static void init_multicast_for_block(
C10_CUDA_DRIVER_CHECK(
driver_api->cuMulticastAddDevice_(mc_handle, block->device_idx));
C10_CUDA_DRIVER_CHECK(driver_api->cuMulticastBindMem_(
mc_handle, 0, block->handle, 0, block->block_size, 0));
mc_handle, 0, block->alloc_ref->handle, 0, block->block_size, 0));
map_block(&mc_addr, mc_handle, block->block_size, block->device_idx);
store_barrier(store, rank, world_size);
@@ -793,19 +789,36 @@
}
c10::intrusive_ptr<SymmetricMemory> CUDASymmetricMemoryAllocator::rendezvous(
void* ptr) {
void* ptr,
const std::optional<std::string>& group_name) {
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
auto block = find_block(ptr);
if (block == nullptr) {
return nullptr;
}
if (block->symm_mem != nullptr) {
return block->symm_mem;
// The group_name passed to rendezvous() takes precedence over
// the default group_name specified during allocation.
std::string group_name_;
if (group_name.has_value()) {
group_name_ = *group_name;
} else {
if (!block->default_group_name.has_value()) {
TORCH_CHECK(
false,
"CUDASymmetricMemory::rendezvous: `group_name` is neither "
"specified during allocation nor passed to rendezvous().");
}
group_name_ = *block->default_group_name;
}
auto it = block->symm_mems.find(group_name_);
if (it != block->symm_mems.end()) {
return it->second;
}
IpcChannel ipc_channel;
auto group_info = get_group_info(block->group_name);
auto group_info = get_group_info(group_name_);
auto store = group_info.store;
int rank = group_info.rank;
int world_size = group_info.world_size;
@@ -813,7 +826,10 @@ c10::intrusive_ptr<SymmetricMemory> CUDASymmetricMemoryAllocator::rendezvous(
auto driver_api = c10::cuda::DriverAPI::get();
int block_fd;
C10_CUDA_DRIVER_CHECK(driver_api->cuMemExportToShareableHandle_(
&block_fd, block->handle, CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, 0));
&block_fd,
block->alloc_ref->handle,
CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR,
0));
auto local_req = RendezvousRequest{
.device_idx = block->device_idx,
@@ -837,7 +853,7 @@ c10::intrusive_ptr<SymmetricMemory> CUDASymmetricMemoryAllocator::rendezvous(
for (int r = 0; r < world_size; ++r) {
if (r == rank) {
handles[r] = block->handle;
handles[r] = block->alloc_ref->handle;
buffers[r] = ptr;
signal_pads[r] = (void*)((uintptr_t)ptr + block->signal_pad_offset);
continue;
@@ -861,13 +877,18 @@ c10::intrusive_ptr<SymmetricMemory> CUDASymmetricMemoryAllocator::rendezvous(
mc_handle, mc_addr, block, ipc_channel, pids, store, rank, world_size);
}
// Initializing CUDASymmetricMemory with an allocation transfers its
// ownership to the CUDASymmetricMemory object. So that outstanding
// references to the CUDASymmetricMemory object can keep the allocation
// alive.
block->symm_mem = c10::make_intrusive<CUDASymmetricMemory>(
std::move(handles),
block->block_size,
std::vector<c10::intrusive_ptr<AllocationRef>> alloc_refs;
for (int r = 0; r < world_size; ++r) {
if (r == rank) {
alloc_refs.emplace_back(block->alloc_ref);
continue;
}
alloc_refs.push_back(c10::make_intrusive<AllocationRef>(
buffers[r], handles[r], block->block_size, block->device_idx));
}
auto symm_mem = c10::make_intrusive<CUDASymmetricMemory>(
std::move(alloc_refs),
std::move(buffers),
std::move(signal_pads),
mc_handle,
@@ -876,22 +897,14 @@ c10::intrusive_ptr<SymmetricMemory> CUDASymmetricMemoryAllocator::rendezvous(
block->device_idx,
group_info.rank,
group_info.world_size);
return block->symm_mem;
block->symm_mems[group_name_] = symm_mem;
return symm_mem;
#else
TORCH_CHECK(
false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED");
#endif
}
bool CUDASymmetricMemoryAllocator::is_rendezvous_completed(void* ptr) {
auto block = find_block(ptr);
TORCH_CHECK(
block != nullptr,
"CUDASymmetricMemoryAllocator::is_rendezvous_completed: input must be allocated ",
"via CUDASymmetricMemoryAllocator::alloc");
return block->symm_mem != nullptr;
}
bool CUDASymmetricMemoryAllocator::has_multicast_support(int device_idx) {
return device_has_multicast_support(device_idx);
}

torch/csrc/distributed/c10d/CUDASymmetricMemory.hpp

@@ -12,11 +12,27 @@ using HandleType = CUmemGenericAllocationHandle;
using HandleType = void*;
#endif
// Resource wrapper that owns a (vaddr, allocation handle) pair. Upon
// destruction, it unmaps the vaddr and releases the allocation handle.
struct AllocationRef : public c10::intrusive_ptr_target {
void* ptr;
HandleType handle;
size_t block_size;
int device_idx;
AllocationRef(
void* ptr,
HandleType handle,
size_t block_size,
int device_idx);
~AllocationRef();
};
class CUDASymmetricMemory : public SymmetricMemory {
public:
CUDASymmetricMemory(
std::vector<HandleType> handles,
size_t block_size,
std::vector<c10::intrusive_ptr<AllocationRef>> alloc_refs,
std::vector<void*> buffers,
std::vector<void*> signal_pads,
HandleType mc_handle,
@@ -26,7 +42,7 @@ class CUDASymmetricMemory : public SymmetricMemory {
int rank,
int world_size);
~CUDASymmetricMemory() override;
~CUDASymmetricMemory() override{};
std::vector<void*> get_buffer_ptrs() override;
std::vector<void*> get_signal_pad_ptrs() override;
@@ -58,8 +74,7 @@ class CUDASymmetricMemory : public SymmetricMemory {
int get_world_size() override;
private:
std::vector<HandleType> handles_;
size_t block_size_;
std::vector<c10::intrusive_ptr<AllocationRef>> alloc_refs_;
std::vector<void*> buffers_;
std::vector<void*> signal_pads_;
HandleType mc_handle_;
@@ -70,43 +85,40 @@ class CUDASymmetricMemory : public SymmetricMemory {
int world_size_;
void** buffers_dev_;
void** signal_pads_dev_;
std::optional<std::function<void(void)>> finalizer_;
};
// Metadata associated with each allocation performed by
// `CUDASymmetricMemoryAllocator`.
struct Block : public c10::intrusive_ptr_target {
HandleType handle;
c10::intrusive_ptr<AllocationRef> alloc_ref;
int device_idx;
size_t block_size;
size_t buffer_size;
size_t signal_pad_offset;
std::string group_name;
c10::intrusive_ptr<CUDASymmetricMemory> symm_mem = nullptr;
std::optional<std::string> default_group_name;
std::map<std::string, c10::intrusive_ptr<CUDASymmetricMemory>> symm_mems;
Block(
HandleType handle,
c10::intrusive_ptr<AllocationRef> alloc_ref,
int device_idx,
size_t block_size,
size_t buffer_size,
size_t signal_pad_offset,
std::string group_name)
: handle(handle),
device_idx(device_idx),
block_size(block_size),
buffer_size(buffer_size),
signal_pad_offset(signal_pad_offset),
group_name(std::move(group_name)),
symm_mem(nullptr) {}
const std::optional<std::string>& group_name);
};
class CUDASymmetricMemoryAllocator : public SymmetricMemoryAllocator {
public:
void* alloc(size_t size, int device_idx, const std::string& group_name)
override;
void* alloc(
size_t size,
int device_idx,
const std::optional<std::string>& group_name) override;
void free(void* ptr) override;
size_t get_alloc_size(void* ptr) override;
c10::intrusive_ptr<SymmetricMemory> rendezvous(void* ptr) override;
bool is_rendezvous_completed(void* ptr) override;
c10::intrusive_ptr<SymmetricMemory> rendezvous(
void* ptr,
const std::optional<std::string>& group_name) override;
bool has_multicast_support(int device_idx) override;
private:

torch/csrc/distributed/c10d/SymmetricMemory.cpp

@@ -56,7 +56,7 @@ static at::Tensor empty_strided_p2p_persistent(
c10::IntArrayRef stride,
c10::ScalarType dtype,
c10::Device device,
const std::string& group_name,
const std::optional<std::string>& group_name,
uint64_t alloc_id) {
// Make the allocation fails if a previous allocation with the same alloc_id
// is still active.
@@ -153,7 +153,7 @@ at::Tensor empty_strided_p2p(
c10::IntArrayRef stride,
c10::ScalarType dtype,
c10::Device device,
const std::string& group_name,
const std::optional<std::string>& group_name,
std::optional<uint64_t> alloc_id) {
if (alloc_id.has_value()) {
return empty_strided_p2p_persistent(
@@ -181,19 +181,10 @@ at::Tensor empty_strided_p2p(
}
TORCH_API c10::intrusive_ptr<SymmetricMemory> rendezvous(
const at::Tensor& tensor) {
const at::Tensor& tensor,
const std::optional<std::string>& group_name) {
auto allocator = get_allocator(tensor.device().type());
return allocator->rendezvous(tensor.storage().data_ptr().get());
}
c10::intrusive_ptr<SymmetricMemory> get_symmetric_memory(
const at::Tensor& tensor) {
auto allocator = get_allocator(tensor.device().type());
TORCH_CHECK(
allocator->is_rendezvous_completed(tensor.data_ptr()),
"SymmetricMemory: must invoke rendezvous on a tensor ",
"before calling get_symmetric_memory on it");
return allocator->rendezvous(tensor.data_ptr());
return allocator->rendezvous(tensor.storage().data_ptr().get(), group_name);
}
TORCH_API bool has_multicast_support(

torch/csrc/distributed/c10d/SymmetricMemory.hpp

@@ -80,12 +80,13 @@ class SymmetricMemoryAllocator : public c10::intrusive_ptr_target {
virtual void* alloc(
size_t size,
int device_idx,
const std::string& group_name) = 0;
const std::optional<std::string>& group_name) = 0;
virtual void free(void* ptr) = 0;
virtual size_t get_alloc_size(void* ptr) = 0;
virtual c10::intrusive_ptr<SymmetricMemory> rendezvous(void* ptr) = 0;
virtual bool is_rendezvous_completed(void* ptr) = 0;
virtual c10::intrusive_ptr<SymmetricMemory> rendezvous(
void* ptr,
const std::optional<std::string>& group_name) = 0;
virtual bool has_multicast_support(int device_idx) = 0;
};
@@ -138,7 +139,7 @@ TORCH_API at::Tensor empty_strided_p2p(
c10::IntArrayRef stride,
c10::ScalarType dtype,
c10::Device device,
const std::string& group_name,
const std::optional<std::string>& group_name,
std::optional<uint64_t> alloc_id);
// Establishes symmetric memory access on tensors allocated via
@@ -152,12 +153,8 @@ TORCH_API at::Tensor empty_strided_p2p(
// The function has a collective semantic and must be invoked simultaneously
// from all rendezvous participants.
TORCH_API c10::intrusive_ptr<SymmetricMemory> rendezvous(
const at::Tensor& tensor);
// Returns the SymmetricMemory object associated with the tensor. It can only
// be invoked after rendezvous() but does not need to be invoked collectively.
TORCH_API c10::intrusive_ptr<SymmetricMemory> get_symmetric_memory(
const at::Tensor& tensor);
const at::Tensor& tensor,
const std::optional<std::string>& group_name = std::nullopt);
TORCH_API bool has_multicast_support(
c10::DeviceType device_type,

torch/csrc/distributed/c10d/init.cpp

@@ -1057,12 +1057,13 @@ This class does not support ``__members__`` property.)");
py::arg("stride"),
py::arg("dtype"),
py::arg("device"),
py::arg("group_name"),
py::arg("group_name") = py::none(),
py::arg("alloc_id") = py::none())
.def_static("rendezvous", &::c10d::symmetric_memory::rendezvous)
.def_static(
"get_symmetric_memory",
&::c10d::symmetric_memory::get_symmetric_memory)
"rendezvous",
&::c10d::symmetric_memory::rendezvous,
py::arg("tensor"),
py::arg("group_name") = py::none())
.def_static(
"has_multicast_support",
&::c10d::symmetric_memory::has_multicast_support)

torch/csrc/distributed/c10d/intra_node_comm.cpp

@@ -176,8 +176,8 @@ bool IntraNodeComm::rendezvous() {
set_group_info(
groupName, static_cast<int>(rank_), static_cast<int>(worldSize_), store_);
auto allocator = get_allocator(c10::DeviceType::CUDA);
symmetricMemoryPtr_ = allocator->alloc(bufferSize_, deviceIdx_, groupName);
symmetricMemory_ = allocator->rendezvous(symmetricMemoryPtr_);
symmetricMemoryPtr_ = allocator->alloc(bufferSize_, deviceIdx_, std::nullopt);
symmetricMemory_ = allocator->rendezvous(symmetricMemoryPtr_, groupName);
isInitialized_ = true;
return true;
#endif