mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[SymmetricMemory] support specifying group_name at rendezvous time (#139529)
Before this PR, users needed to call `empty_strided_p2p()` with a `group_name`:

```python
tensor = _SymmetricMemory.empty_strided_p2p((1024,), (1,), device=device, group_name="0")
symm_mem = _SymmetricMemory.rendezvous(tensor)
```

Users can now omit `group_name` at allocation time and specify it later at rendezvous time:

```python
tensor = _SymmetricMemory.empty_strided_p2p((1024,), (1,), device=device)
symm_mem = _SymmetricMemory.rendezvous(tensor, group_name="0")
```

Rationale for this change:
- This allows the same allocation to establish symmetric memory under different groups
- Specifying `group_name` at rendezvous time instead of allocation time is a more natural UX

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139529
Approved by: https://github.com/lw
This commit is contained in:
parent
602ae9cbcf
commit
ab5c8857ef
8 changed files with 225 additions and 143 deletions
|
|
@ -97,7 +97,7 @@ class SymmetricMemoryTest(MultiProcessTestCase):
|
|||
|
||||
buf = symm_mem.get_buffer(0, (symm_mem.buffer_size // 4,), torch.float32)
|
||||
self.assertEqual(buf.storage_offset(), 0)
|
||||
self.assertEqual(buf.storage().size(), symm_mem.buffer_size // 4)
|
||||
self.assertEqual(buf.untyped_storage().size(), symm_mem.buffer_size)
|
||||
|
||||
if symm_mem.rank == 0:
|
||||
symm_mem.wait_signal(src_rank=1)
|
||||
|
|
@ -168,16 +168,8 @@ class SymmetricMemoryTest(MultiProcessTestCase):
|
|||
t = _SymmetricMemory.empty_strided_p2p(*alloc_args, alloc_id=42)
|
||||
self.assertEqual(t.data_ptr(), data_ptr)
|
||||
|
||||
# Verify that get_symmetric_memory would fail if called before
|
||||
# rendezvous.
|
||||
with self.assertRaises(RuntimeError):
|
||||
_SymmetricMemory.get_symmetric_memory(t)
|
||||
|
||||
symm_mem_0 = _SymmetricMemory.rendezvous(t)
|
||||
symm_mem_1 = _SymmetricMemory.get_symmetric_memory(t)
|
||||
self.assertEqual(id(symm_mem_0), id(symm_mem_1))
|
||||
|
||||
self._verify_symmetric_memory(symm_mem_0)
|
||||
symm_mem = _SymmetricMemory.rendezvous(t)
|
||||
self._verify_symmetric_memory(symm_mem)
|
||||
dist.destroy_process_group()
|
||||
|
||||
@skipIfRocm
|
||||
|
|
@ -669,6 +661,79 @@ class SymmetricMemoryTest(MultiProcessTestCase):
|
|||
dist.destroy_process_group()
|
||||
|
||||
|
||||
@instantiate_parametrized_tests
|
||||
@requires_cuda_p2p_access()
|
||||
class SubgroupTest(MultiProcessTestCase):
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
self._spawn_processes()
|
||||
|
||||
@property
|
||||
def world_size(self) -> int:
|
||||
return 4
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return torch.device(f"cuda:{self.rank}")
|
||||
|
||||
def _init_process(self):
|
||||
torch.cuda.set_device(self.device)
|
||||
store = dist.FileStore(self.file_name, self.world_size)
|
||||
dist.init_process_group(
|
||||
backend="nccl",
|
||||
world_size=self.world_size,
|
||||
rank=self.rank,
|
||||
store=store,
|
||||
)
|
||||
enable_symm_mem_for_group(dist.group.WORLD.group_name)
|
||||
torch.manual_seed(42 + self.rank)
|
||||
|
||||
@skipIfRocm
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_subgroup(self) -> None:
|
||||
self._init_process()
|
||||
|
||||
ranks = list(range(self.world_size))
|
||||
subgroup_0 = dist.new_group(ranks[: len(ranks) // 2])
|
||||
subgroup_1 = dist.new_group(ranks[len(ranks) // 2 :])
|
||||
|
||||
world = dist.group.WORLD
|
||||
subgroup = subgroup_0 if world.rank() < world.size() // 2 else subgroup_1
|
||||
enable_symm_mem_for_group(subgroup.group_name)
|
||||
|
||||
t = _SymmetricMemory.empty_strided_p2p(
|
||||
size=(64,),
|
||||
stride=(1,),
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
symm_mem_world = _SymmetricMemory.rendezvous(t, group_name=world.group_name)
|
||||
symm_mem_subgroup = _SymmetricMemory.rendezvous(
|
||||
t, group_name=subgroup.group_name
|
||||
)
|
||||
|
||||
self.assertEqual(symm_mem_world.world_size, world.size())
|
||||
self.assertEqual(symm_mem_world.rank, world.rank())
|
||||
self.assertEqual(symm_mem_subgroup.world_size, world.size() // 2)
|
||||
self.assertEqual(symm_mem_subgroup.rank, world.rank() % subgroup.size())
|
||||
|
||||
t.fill_(world.rank())
|
||||
symm_mem_world.barrier()
|
||||
|
||||
# Observe a peer buffer via the world group
|
||||
peer_rank = (world.rank() + 1) % world.size()
|
||||
buf = symm_mem_world.get_buffer(peer_rank, (64,), torch.float32)
|
||||
self.assertTrue(buf.eq(peer_rank).all())
|
||||
|
||||
# Observe a peer buffer via the subgroup
|
||||
peer_rank = (subgroup.rank() + 1) % subgroup.size()
|
||||
buf = symm_mem_subgroup.get_buffer(peer_rank, (64,), torch.float32)
|
||||
if world.rank() < world.size() // 2:
|
||||
self.assertTrue(buf.eq(peer_rank).all())
|
||||
else:
|
||||
self.assertTrue(buf.eq(peer_rank + world.size() // 2).all())
|
||||
|
||||
|
||||
@instantiate_parametrized_tests
|
||||
@requires_cuda_p2p_access()
|
||||
class SymmMemAllReduceTest(MultiProcessTestCase):
|
||||
|
|
|
|||
|
|
@ -662,14 +662,17 @@ class _SymmetricMemory:
|
|||
stride: torch.types._size,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
group_name: str,
|
||||
group_name: str | None = None,
|
||||
alloc_id: int | None = None,
|
||||
) -> torch.Tensor: ...
|
||||
@property
|
||||
def rank(self) -> int: ...
|
||||
@property
|
||||
def world_size(self) -> int: ...
|
||||
@staticmethod
|
||||
def rendezvous(tensor: torch.Tensor) -> _SymmetricMemory: ...
|
||||
def rendezvous(
|
||||
tensor: torch.Tensor, group_name: str | None = None
|
||||
) -> _SymmetricMemory: ...
|
||||
def get_buffer(
|
||||
self,
|
||||
rank: int,
|
||||
|
|
|
|||
|
|
@ -95,7 +95,6 @@ class IpcChannel {
|
|||
cmsg->cmsg_type = SCM_RIGHTS;
|
||||
|
||||
if (fd != -1) {
|
||||
// memcpy(CMSG_DATA(cmsg), &fd, sizeof(fd));
|
||||
std::copy(
|
||||
reinterpret_cast<const char*>(&fd),
|
||||
reinterpret_cast<const char*>(&fd) + sizeof(fd),
|
||||
|
|
@ -273,9 +272,29 @@ void map_block(
|
|||
namespace c10d {
|
||||
namespace symmetric_memory {
|
||||
|
||||
AllocationRef::AllocationRef(void* ptr, HandleType handle, size_t block_size, int device_idx)
|
||||
: ptr(ptr), handle(handle), block_size(block_size), device_idx(device_idx) {}
|
||||
|
||||
AllocationRef::~AllocationRef() {
|
||||
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
|
||||
// Leak the cuda allocations during static deinitialization
|
||||
if (is_finalizing()) {
|
||||
return;
|
||||
}
|
||||
auto driver_api = c10::cuda::DriverAPI::get();
|
||||
c10::cuda::CUDAGuard guard(device_idx);
|
||||
C10_CUDA_CHECK(cudaDeviceSynchronize());
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
driver_api->cuMemUnmap_(reinterpret_cast<CUdeviceptr>(ptr), block_size));
|
||||
C10_CUDA_DRIVER_CHECK(driver_api->cuMemRelease_(handle));
|
||||
#else
|
||||
TORCH_CHECK(
|
||||
false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED");
|
||||
#endif
|
||||
}
|
||||
|
||||
CUDASymmetricMemory::CUDASymmetricMemory(
|
||||
std::vector<HandleType> handles,
|
||||
size_t block_size,
|
||||
std::vector<c10::intrusive_ptr<AllocationRef>> alloc_refs,
|
||||
std::vector<void*> buffers,
|
||||
std::vector<void*> signal_pads,
|
||||
HandleType mc_handle,
|
||||
|
|
@ -284,8 +303,7 @@ CUDASymmetricMemory::CUDASymmetricMemory(
|
|||
int local_device_idx,
|
||||
int rank,
|
||||
int world_size)
|
||||
: handles_(std::move(handles)),
|
||||
block_size_(block_size),
|
||||
: alloc_refs_(std::move(alloc_refs)),
|
||||
buffers_(std::move(buffers)),
|
||||
signal_pads_(std::move(signal_pads)),
|
||||
mc_handle_(mc_handle),
|
||||
|
|
@ -307,29 +325,6 @@ CUDASymmetricMemory::CUDASymmetricMemory(
|
|||
signal_pads_dev_, signal_pads_.data(), arr_size, cudaMemcpyHostToDevice));
|
||||
}
|
||||
|
||||
CUDASymmetricMemory::~CUDASymmetricMemory() {
|
||||
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
|
||||
// Leak the cuda allocations during static deinitialization
|
||||
if (is_finalizing()) {
|
||||
return;
|
||||
}
|
||||
c10::cuda::CUDAGuard guard(local_device_idx_);
|
||||
C10_CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
auto driver_api = c10::cuda::DriverAPI::get();
|
||||
for (int r = 0; r < world_size_; ++r) {
|
||||
C10_CUDA_DRIVER_CHECK(driver_api->cuMemUnmap_(
|
||||
reinterpret_cast<CUdeviceptr>(buffers_[r]), block_size_));
|
||||
C10_CUDA_DRIVER_CHECK(driver_api->cuMemRelease_(handles_[r]));
|
||||
}
|
||||
c10::cuda::CUDACachingAllocator::raw_delete(buffers_dev_);
|
||||
c10::cuda::CUDACachingAllocator::raw_delete(signal_pads_dev_);
|
||||
#else
|
||||
TORCH_CHECK(
|
||||
false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED");
|
||||
#endif
|
||||
}
|
||||
|
||||
std::vector<void*> CUDASymmetricMemory::get_buffer_ptrs() {
|
||||
return buffers_;
|
||||
}
|
||||
|
|
@ -592,12 +587,27 @@ int CUDASymmetricMemory::get_world_size() {
|
|||
return world_size_;
|
||||
}
|
||||
|
||||
Block::Block(
|
||||
c10::intrusive_ptr<AllocationRef> alloc_ref,
|
||||
int device_idx,
|
||||
size_t block_size,
|
||||
size_t buffer_size,
|
||||
size_t signal_pad_offset,
|
||||
const std::optional<std::string>& group_name)
|
||||
: alloc_ref(std::move(alloc_ref)),
|
||||
device_idx(device_idx),
|
||||
block_size(block_size),
|
||||
buffer_size(buffer_size),
|
||||
signal_pad_offset(signal_pad_offset),
|
||||
default_group_name(std::move(group_name)) {}
|
||||
|
||||
void* CUDASymmetricMemoryAllocator::alloc(
|
||||
size_t size,
|
||||
int device_idx,
|
||||
const std::string& group_name) {
|
||||
const std::optional<std::string>& group_name) {
|
||||
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
|
||||
auto driver_api = c10::cuda::DriverAPI::get();
|
||||
c10::cuda::CUDAGuard guard(device_idx);
|
||||
device_idx = static_cast<int>(guard.current_device().index());
|
||||
|
||||
CUmemAllocationProp prop = {};
|
||||
prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
|
||||
|
|
@ -610,6 +620,7 @@ void* CUDASymmetricMemoryAllocator::alloc(
|
|||
size_t block_size = signal_pad_offset + signal_pad_size;
|
||||
|
||||
size_t granularity;
|
||||
auto driver_api = c10::cuda::DriverAPI::get();
|
||||
C10_CUDA_DRIVER_CHECK(driver_api->cuMemGetAllocationGranularity_(
|
||||
&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED));
|
||||
block_size = at::round_up(block_size, granularity);
|
||||
|
|
@ -621,11 +632,16 @@ void* CUDASymmetricMemoryAllocator::alloc(
|
|||
void* ptr = nullptr;
|
||||
map_block(&ptr, handle, block_size, device_idx);
|
||||
|
||||
c10::cuda::CUDAGuard guard(device_idx);
|
||||
AT_CUDA_CHECK(cudaMemset(ptr, 0, block_size));
|
||||
|
||||
auto alloc_ref = c10::make_intrusive<AllocationRef>(ptr, handle, block_size, device_idx);
|
||||
auto block = c10::make_intrusive<Block>(
|
||||
handle, device_idx, block_size, size, signal_pad_offset, group_name);
|
||||
std::move(alloc_ref),
|
||||
device_idx,
|
||||
block_size,
|
||||
size,
|
||||
signal_pad_offset,
|
||||
group_name);
|
||||
{
|
||||
std::unique_lock lock(mutex_);
|
||||
ptr_to_block_.emplace(ptr, std::move(block));
|
||||
|
|
@ -638,28 +654,8 @@ void* CUDASymmetricMemoryAllocator::alloc(
|
|||
}
|
||||
|
||||
void CUDASymmetricMemoryAllocator::free(void* ptr) {
|
||||
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
|
||||
auto block = find_block(ptr);
|
||||
// Leak the cuda allocations during static deinitialization
|
||||
if (block == nullptr || is_finalizing()) {
|
||||
return;
|
||||
}
|
||||
// Initializing CUDASymmetricMemory with an allocation transfers its
|
||||
// ownership to the CUDASymmetricMemory object.
|
||||
if (block->symm_mem == nullptr) {
|
||||
auto driver_api = c10::cuda::DriverAPI::get();
|
||||
C10_CUDA_DRIVER_CHECK(driver_api->cuMemUnmap_(
|
||||
reinterpret_cast<CUdeviceptr>(ptr), block->block_size));
|
||||
C10_CUDA_DRIVER_CHECK(driver_api->cuMemRelease_(block->handle));
|
||||
}
|
||||
{
|
||||
std::unique_lock lock(mutex_);
|
||||
ptr_to_block_.erase(ptr);
|
||||
}
|
||||
#else
|
||||
TORCH_CHECK(
|
||||
false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED");
|
||||
#endif
|
||||
std::unique_lock lock(mutex_);
|
||||
ptr_to_block_.erase(ptr);
|
||||
}
|
||||
|
||||
size_t CUDASymmetricMemoryAllocator::get_alloc_size(void* ptr) {
|
||||
|
|
@ -785,7 +781,7 @@ static void init_multicast_for_block(
|
|||
C10_CUDA_DRIVER_CHECK(
|
||||
driver_api->cuMulticastAddDevice_(mc_handle, block->device_idx));
|
||||
C10_CUDA_DRIVER_CHECK(driver_api->cuMulticastBindMem_(
|
||||
mc_handle, 0, block->handle, 0, block->block_size, 0));
|
||||
mc_handle, 0, block->alloc_ref->handle, 0, block->block_size, 0));
|
||||
|
||||
map_block(&mc_addr, mc_handle, block->block_size, block->device_idx);
|
||||
store_barrier(store, rank, world_size);
|
||||
|
|
@ -793,19 +789,36 @@ static void init_multicast_for_block(
|
|||
}
|
||||
|
||||
c10::intrusive_ptr<SymmetricMemory> CUDASymmetricMemoryAllocator::rendezvous(
|
||||
void* ptr) {
|
||||
void* ptr,
|
||||
const std::optional<std::string>& group_name) {
|
||||
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
|
||||
auto block = find_block(ptr);
|
||||
if (block == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (block->symm_mem != nullptr) {
|
||||
return block->symm_mem;
|
||||
// The group_name passed to rendezvous() takes precedence over
|
||||
// the default group_name specified during allocation.
|
||||
std::string group_name_;
|
||||
if (group_name.has_value()) {
|
||||
group_name_ = *group_name;
|
||||
} else {
|
||||
if (!block->default_group_name.has_value()) {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"CUDASymmetricMemory::rendezvous: `group_name` is neither "
|
||||
"specified during allocation nor passed to rendezvous().");
|
||||
}
|
||||
group_name_ = *block->default_group_name;
|
||||
}
|
||||
|
||||
auto it = block->symm_mems.find(group_name_);
|
||||
if (it != block->symm_mems.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
IpcChannel ipc_channel;
|
||||
auto group_info = get_group_info(block->group_name);
|
||||
auto group_info = get_group_info(group_name_);
|
||||
auto store = group_info.store;
|
||||
int rank = group_info.rank;
|
||||
int world_size = group_info.world_size;
|
||||
|
|
@ -813,7 +826,10 @@ c10::intrusive_ptr<SymmetricMemory> CUDASymmetricMemoryAllocator::rendezvous(
|
|||
auto driver_api = c10::cuda::DriverAPI::get();
|
||||
int block_fd;
|
||||
C10_CUDA_DRIVER_CHECK(driver_api->cuMemExportToShareableHandle_(
|
||||
&block_fd, block->handle, CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, 0));
|
||||
&block_fd,
|
||||
block->alloc_ref->handle,
|
||||
CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR,
|
||||
0));
|
||||
|
||||
auto local_req = RendezvousRequest{
|
||||
.device_idx = block->device_idx,
|
||||
|
|
@ -837,7 +853,7 @@ c10::intrusive_ptr<SymmetricMemory> CUDASymmetricMemoryAllocator::rendezvous(
|
|||
|
||||
for (int r = 0; r < world_size; ++r) {
|
||||
if (r == rank) {
|
||||
handles[r] = block->handle;
|
||||
handles[r] = block->alloc_ref->handle;
|
||||
buffers[r] = ptr;
|
||||
signal_pads[r] = (void*)((uintptr_t)ptr + block->signal_pad_offset);
|
||||
continue;
|
||||
|
|
@ -861,13 +877,18 @@ c10::intrusive_ptr<SymmetricMemory> CUDASymmetricMemoryAllocator::rendezvous(
|
|||
mc_handle, mc_addr, block, ipc_channel, pids, store, rank, world_size);
|
||||
}
|
||||
|
||||
// Initializing CUDASymmetricMemory with an allocation transfers its
|
||||
// ownership to the CUDASymmetricMemory object. So that outstanding
|
||||
// references to the CUDASymmetricMemory object can keep the allocation
|
||||
// alive.
|
||||
block->symm_mem = c10::make_intrusive<CUDASymmetricMemory>(
|
||||
std::move(handles),
|
||||
block->block_size,
|
||||
std::vector<c10::intrusive_ptr<AllocationRef>> alloc_refs;
|
||||
for (int r = 0; r < world_size; ++r) {
|
||||
if (r == rank) {
|
||||
alloc_refs.emplace_back(block->alloc_ref);
|
||||
continue;
|
||||
}
|
||||
alloc_refs.push_back(c10::make_intrusive<AllocationRef>(
|
||||
buffers[r], handles[r], block->block_size, block->device_idx));
|
||||
}
|
||||
|
||||
auto symm_mem = c10::make_intrusive<CUDASymmetricMemory>(
|
||||
std::move(alloc_refs),
|
||||
std::move(buffers),
|
||||
std::move(signal_pads),
|
||||
mc_handle,
|
||||
|
|
@ -876,22 +897,14 @@ c10::intrusive_ptr<SymmetricMemory> CUDASymmetricMemoryAllocator::rendezvous(
|
|||
block->device_idx,
|
||||
group_info.rank,
|
||||
group_info.world_size);
|
||||
return block->symm_mem;
|
||||
block->symm_mems[group_name_] = symm_mem;
|
||||
return symm_mem;
|
||||
#else
|
||||
TORCH_CHECK(
|
||||
false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED");
|
||||
#endif
|
||||
}
|
||||
|
||||
bool CUDASymmetricMemoryAllocator::is_rendezvous_completed(void* ptr) {
|
||||
auto block = find_block(ptr);
|
||||
TORCH_CHECK(
|
||||
block != nullptr,
|
||||
"CUDASymmetricMemoryAllocator::is_rendezvous_completed: input must be allocated ",
|
||||
"via CUDASymmetricMemoryAllocator::alloc");
|
||||
return block->symm_mem != nullptr;
|
||||
}
|
||||
|
||||
bool CUDASymmetricMemoryAllocator::has_multicast_support(int device_idx) {
|
||||
return device_has_multicast_support(device_idx);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -12,11 +12,27 @@ using HandleType = CUmemGenericAllocationHandle;
|
|||
using HandleType = void*;
|
||||
#endif
|
||||
|
||||
// Resource wrapper that owns a (vaddr, allocation handle) pair. Upon
|
||||
// destruction, it unmaps the vaddr and releases the allocation handle.
|
||||
struct AllocationRef : public c10::intrusive_ptr_target {
|
||||
void* ptr;
|
||||
HandleType handle;
|
||||
size_t block_size;
|
||||
int device_idx;
|
||||
|
||||
AllocationRef(
|
||||
void* ptr,
|
||||
HandleType handle,
|
||||
size_t block_size,
|
||||
int device_idx);
|
||||
|
||||
~AllocationRef();
|
||||
};
|
||||
|
||||
class CUDASymmetricMemory : public SymmetricMemory {
|
||||
public:
|
||||
CUDASymmetricMemory(
|
||||
std::vector<HandleType> handles,
|
||||
size_t block_size,
|
||||
std::vector<c10::intrusive_ptr<AllocationRef>> alloc_refs,
|
||||
std::vector<void*> buffers,
|
||||
std::vector<void*> signal_pads,
|
||||
HandleType mc_handle,
|
||||
|
|
@ -26,7 +42,7 @@ class CUDASymmetricMemory : public SymmetricMemory {
|
|||
int rank,
|
||||
int world_size);
|
||||
|
||||
~CUDASymmetricMemory() override;
|
||||
~CUDASymmetricMemory() override{};
|
||||
|
||||
std::vector<void*> get_buffer_ptrs() override;
|
||||
std::vector<void*> get_signal_pad_ptrs() override;
|
||||
|
|
@ -58,8 +74,7 @@ class CUDASymmetricMemory : public SymmetricMemory {
|
|||
int get_world_size() override;
|
||||
|
||||
private:
|
||||
std::vector<HandleType> handles_;
|
||||
size_t block_size_;
|
||||
std::vector<c10::intrusive_ptr<AllocationRef>> alloc_refs_;
|
||||
std::vector<void*> buffers_;
|
||||
std::vector<void*> signal_pads_;
|
||||
HandleType mc_handle_;
|
||||
|
|
@ -70,43 +85,40 @@ class CUDASymmetricMemory : public SymmetricMemory {
|
|||
int world_size_;
|
||||
void** buffers_dev_;
|
||||
void** signal_pads_dev_;
|
||||
std::optional<std::function<void(void)>> finalizer_;
|
||||
};
|
||||
|
||||
// Metadata associated with each allocation performed by
|
||||
// `CUDASymmetricMemoryAllocator`.
|
||||
struct Block : public c10::intrusive_ptr_target {
|
||||
HandleType handle;
|
||||
c10::intrusive_ptr<AllocationRef> alloc_ref;
|
||||
int device_idx;
|
||||
size_t block_size;
|
||||
size_t buffer_size;
|
||||
size_t signal_pad_offset;
|
||||
std::string group_name;
|
||||
c10::intrusive_ptr<CUDASymmetricMemory> symm_mem = nullptr;
|
||||
std::optional<std::string> default_group_name;
|
||||
std::map<std::string, c10::intrusive_ptr<CUDASymmetricMemory>> symm_mems;
|
||||
|
||||
Block(
|
||||
HandleType handle,
|
||||
c10::intrusive_ptr<AllocationRef> alloc_ref,
|
||||
int device_idx,
|
||||
size_t block_size,
|
||||
size_t buffer_size,
|
||||
size_t signal_pad_offset,
|
||||
std::string group_name)
|
||||
: handle(handle),
|
||||
device_idx(device_idx),
|
||||
block_size(block_size),
|
||||
buffer_size(buffer_size),
|
||||
signal_pad_offset(signal_pad_offset),
|
||||
group_name(std::move(group_name)),
|
||||
symm_mem(nullptr) {}
|
||||
const std::optional<std::string>& group_name);
|
||||
};
|
||||
|
||||
class CUDASymmetricMemoryAllocator : public SymmetricMemoryAllocator {
|
||||
public:
|
||||
void* alloc(size_t size, int device_idx, const std::string& group_name)
|
||||
override;
|
||||
void* alloc(
|
||||
size_t size,
|
||||
int device_idx,
|
||||
const std::optional<std::string>& group_name) override;
|
||||
|
||||
void free(void* ptr) override;
|
||||
size_t get_alloc_size(void* ptr) override;
|
||||
c10::intrusive_ptr<SymmetricMemory> rendezvous(void* ptr) override;
|
||||
bool is_rendezvous_completed(void* ptr) override;
|
||||
c10::intrusive_ptr<SymmetricMemory> rendezvous(
|
||||
void* ptr,
|
||||
const std::optional<std::string>& group_name) override;
|
||||
bool has_multicast_support(int device_idx) override;
|
||||
|
||||
private:
|
||||
|
|
|
|||
|
|
@ -56,7 +56,7 @@ static at::Tensor empty_strided_p2p_persistent(
|
|||
c10::IntArrayRef stride,
|
||||
c10::ScalarType dtype,
|
||||
c10::Device device,
|
||||
const std::string& group_name,
|
||||
const std::optional<std::string>& group_name,
|
||||
uint64_t alloc_id) {
|
||||
// Make the allocation fail if a previous allocation with the same alloc_id
|
||||
// is still active.
|
||||
|
|
@ -153,7 +153,7 @@ at::Tensor empty_strided_p2p(
|
|||
c10::IntArrayRef stride,
|
||||
c10::ScalarType dtype,
|
||||
c10::Device device,
|
||||
const std::string& group_name,
|
||||
const std::optional<std::string>& group_name,
|
||||
std::optional<uint64_t> alloc_id) {
|
||||
if (alloc_id.has_value()) {
|
||||
return empty_strided_p2p_persistent(
|
||||
|
|
@ -181,19 +181,10 @@ at::Tensor empty_strided_p2p(
|
|||
}
|
||||
|
||||
TORCH_API c10::intrusive_ptr<SymmetricMemory> rendezvous(
|
||||
const at::Tensor& tensor) {
|
||||
const at::Tensor& tensor,
|
||||
const std::optional<std::string>& group_name) {
|
||||
auto allocator = get_allocator(tensor.device().type());
|
||||
return allocator->rendezvous(tensor.storage().data_ptr().get());
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<SymmetricMemory> get_symmetric_memory(
|
||||
const at::Tensor& tensor) {
|
||||
auto allocator = get_allocator(tensor.device().type());
|
||||
TORCH_CHECK(
|
||||
allocator->is_rendezvous_completed(tensor.data_ptr()),
|
||||
"SymmetricMemory: must invoke rendezvous on a tensor ",
|
||||
"before calling get_symmetric_memory on it");
|
||||
return allocator->rendezvous(tensor.data_ptr());
|
||||
return allocator->rendezvous(tensor.storage().data_ptr().get(), group_name);
|
||||
}
|
||||
|
||||
TORCH_API bool has_multicast_support(
|
||||
|
|
|
|||
|
|
@ -80,12 +80,13 @@ class SymmetricMemoryAllocator : public c10::intrusive_ptr_target {
|
|||
virtual void* alloc(
|
||||
size_t size,
|
||||
int device_idx,
|
||||
const std::string& group_name) = 0;
|
||||
const std::optional<std::string>& group_name) = 0;
|
||||
|
||||
virtual void free(void* ptr) = 0;
|
||||
virtual size_t get_alloc_size(void* ptr) = 0;
|
||||
virtual c10::intrusive_ptr<SymmetricMemory> rendezvous(void* ptr) = 0;
|
||||
virtual bool is_rendezvous_completed(void* ptr) = 0;
|
||||
virtual c10::intrusive_ptr<SymmetricMemory> rendezvous(
|
||||
void* ptr,
|
||||
const std::optional<std::string>& group_name) = 0;
|
||||
virtual bool has_multicast_support(int device_idx) = 0;
|
||||
};
|
||||
|
||||
|
|
@ -138,7 +139,7 @@ TORCH_API at::Tensor empty_strided_p2p(
|
|||
c10::IntArrayRef stride,
|
||||
c10::ScalarType dtype,
|
||||
c10::Device device,
|
||||
const std::string& group_name,
|
||||
const std::optional<std::string>& group_name,
|
||||
std::optional<uint64_t> alloc_id);
|
||||
|
||||
// Establishes symmetric memory access on tensors allocated via
|
||||
|
|
@ -152,12 +153,8 @@ TORCH_API at::Tensor empty_strided_p2p(
|
|||
// The function has a collective semantic and must be invoked simultaneously
|
||||
// from all rendezvous participants.
|
||||
TORCH_API c10::intrusive_ptr<SymmetricMemory> rendezvous(
|
||||
const at::Tensor& tensor);
|
||||
|
||||
// Returns the SymmetricMemory object associated with the tensor. It can only
|
||||
// be invoked after rendezvous() but does not need to be invoked collectively.
|
||||
TORCH_API c10::intrusive_ptr<SymmetricMemory> get_symmetric_memory(
|
||||
const at::Tensor& tensor);
|
||||
const at::Tensor& tensor,
|
||||
const std::optional<std::string>& group_name = std::nullopt);
|
||||
|
||||
TORCH_API bool has_multicast_support(
|
||||
c10::DeviceType device_type,
|
||||
|
|
|
|||
|
|
@ -1057,12 +1057,13 @@ This class does not support ``__members__`` property.)");
|
|||
py::arg("stride"),
|
||||
py::arg("dtype"),
|
||||
py::arg("device"),
|
||||
py::arg("group_name"),
|
||||
py::arg("group_name") = py::none(),
|
||||
py::arg("alloc_id") = py::none())
|
||||
.def_static("rendezvous", &::c10d::symmetric_memory::rendezvous)
|
||||
.def_static(
|
||||
"get_symmetric_memory",
|
||||
&::c10d::symmetric_memory::get_symmetric_memory)
|
||||
"rendezvous",
|
||||
&::c10d::symmetric_memory::rendezvous,
|
||||
py::arg("tensor"),
|
||||
py::arg("group_name") = py::none())
|
||||
.def_static(
|
||||
"has_multicast_support",
|
||||
&::c10d::symmetric_memory::has_multicast_support)
|
||||
|
|
|
|||
|
|
@ -176,8 +176,8 @@ bool IntraNodeComm::rendezvous() {
|
|||
set_group_info(
|
||||
groupName, static_cast<int>(rank_), static_cast<int>(worldSize_), store_);
|
||||
auto allocator = get_allocator(c10::DeviceType::CUDA);
|
||||
symmetricMemoryPtr_ = allocator->alloc(bufferSize_, deviceIdx_, groupName);
|
||||
symmetricMemory_ = allocator->rendezvous(symmetricMemoryPtr_);
|
||||
symmetricMemoryPtr_ = allocator->alloc(bufferSize_, deviceIdx_, std::nullopt);
|
||||
symmetricMemory_ = allocator->rendezvous(symmetricMemoryPtr_, groupName);
|
||||
isInitialized_ = true;
|
||||
return true;
|
||||
#endif
|
||||
|
|
|
|||
Loading…
Reference in a new issue