Implements user buffer registration using MemPool (#133603)

This PR implements user buffer registration and demonstrates NVLink SHARP (NVLS) reductions by allocating special memory with MemPool and registering it with the NCCL buffer registration APIs.

Part of https://github.com/pytorch/pytorch/issues/124807.
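
For context, a minimal sketch of the user-facing flow this enables, mirroring the test added below (the shared-library path is a placeholder; `nccl_alloc`/`nccl_free` are the entry points of the pluggable allocator compiled in the test):

    import torch
    import torch.distributed as c10d

    c10d.init_process_group(backend="nccl")
    pg = c10d.distributed_c10d._get_default_group()
    backend = pg._get_backend(torch.device("cuda"))

    # Back a MemPool with a pluggable allocator that wraps
    # ncclMemAlloc/ncclMemFree ("nccl_allocator.so" is a placeholder
    # for the compiled extension).
    allocator = torch.cuda.memory.CUDAPluggableAllocator(
        "nccl_allocator.so", "nccl_alloc", "nccl_free"
    )
    pool = torch.cuda.MemPool(allocator.allocator())

    # Allocations made while the pool is active go through ncclMemAlloc.
    with torch.cuda.use_mem_pool(pool):
        tensor = torch.arange(1024 * 1024 * 2, device="cuda")

    # Register all segments in the pool with NCCL; with NCCL_ALGO=NVLS set,
    # collectives on the registered buffers can use NVLink SHARP reductions.
    backend.register_mem_pool(pool)
    pg.allreduce(tensor).wait()
    backend.deregister_mem_pool(pool)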

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133603
Approved by: https://github.com/kwen2501, https://github.com/eqy
Syed Tousif Ahmed 2024-11-20 13:04:02 -08:00 committed by PyTorch MergeBot
parent b44ecd91ba
commit e0482fdf95
9 changed files with 187 additions and 11 deletions

View file

@@ -4000,6 +4000,10 @@ int MemPool::use_count() {
  return CUDACachingAllocator::getPoolUseCount(device_, id_);
}

c10::DeviceIndex MemPool::device() {
  return device_;
}

MempoolId_t MemPool::graph_pool_handle(bool is_user_created) {
  if (is_user_created) {
    return {0, uid_++};

View file

@@ -506,6 +506,7 @@ struct C10_CUDA_API MemPool {
  MempoolId_t id();
  CUDACachingAllocator::CUDAAllocator* allocator();
  int use_count();
  c10::DeviceIndex device();
  static MempoolId_t graph_pool_handle(bool is_user_created = true);

 private:

View file

@@ -46,6 +46,7 @@ from torch.testing._internal.common_distributed import (
    init_multigpu_helper,
    MultiProcessTestCase,
    requires_gloo,
    requires_multicast_support,
    requires_nccl,
    requires_nccl_version,
    skip_if_lt_x_gpu,
@@ -67,6 +68,7 @@ from torch.testing._internal.common_utils import (
    TEST_WITH_ROCM,
    TestCase,
)
from torch.utils.cpp_extension import load_inline


if TEST_WITH_DEV_DBG_ASAN:
@@ -2961,6 +2963,105 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
        ) from e


class NcclUserBufferRegistrationTest(MultiProcessTestCase):
    def createNcclAllocator(self):
        nccl_allocator_source = """
        #include <torch/extension.h>
        #include <nccl.h>
        #include <iostream>

        extern "C" {
          // Note that windows needs __declspec(dllexport): https://stackoverflow.com/a/24575865
          C10_EXPORT void* nccl_alloc(size_t size, int device, void* stream) {
            std::cout << "Using ncclMemAlloc" << std::endl;
            void* ptr;
            ncclResult_t err = ncclMemAlloc(&ptr, size);
            return ptr;
          }

          C10_EXPORT void nccl_free(void* ptr, size_t size, int device, void* stream) {
            std::cout << "Using ncclMemFree" << std::endl;
            ncclResult_t err = ncclMemFree(ptr);
          }
        }
        """
        nccl_allocator_libname = "nccl_allocator"
        nccl_allocator = load_inline(
            name=nccl_allocator_libname,
            cpp_sources=nccl_allocator_source,
            with_cuda=True,
            extra_ldflags=["-lnccl"],
            is_python_module=False,
            keep_intermediates=False,
            verbose=True,
        )
        return nccl_allocator

    def setUp(self):
        super().setUp()
        # TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING, hence
        # tests that use TORCH_NCCL_BLOCKING_WAIT will test it as expected.
        os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
        nccl_debug_file = tempfile.NamedTemporaryFile()
        os.environ["NCCL_ALGO"] = "NVLS"
        os.environ["NCCL_DEBUG"] = "INFO"
        os.environ["NCCL_DEBUG_SUBSYS"] = "NVLS"
        os.environ["NCCL_DEBUG_FILE"] = nccl_debug_file.name
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    @requires_nccl()
    @requires_nccl_version((2, 19), "Need NCCL 2.19 for user buffer registration")
    @skip_if_lt_x_gpu(4)
    @requires_multicast_support()
    def test_nccl_user_buffer_registration(self):
        store = c10d.FileStore(self.file_name, self.world_size)
        c10d.init_process_group(
            backend="nccl", rank=self.rank, world_size=self.world_size, store=store
        )
        device = torch.device(f"cuda:{self.rank}")
        torch.cuda.set_device(self.rank)
        pg = c10d.distributed_c10d._get_default_group()
        backend = pg._get_backend(torch.device(device))

        allocator_path = self.createNcclAllocator()
        allocator = torch.cuda.memory.CUDAPluggableAllocator(
            allocator_path,
            "nccl_alloc",
            "nccl_free",
        )
        pool = torch.cuda.MemPool(allocator.allocator())

        # allocate memory with ncclMemAlloc
        with torch.cuda.use_mem_pool(pool):
            tensor = torch.arange(1024 * 1024 * 2, device=device)

        # register buffers to NCCL
        backend.register_mem_pool(pool)

        # allreduce now should use NVIDIA Switches
        pg.allreduce(tensor).wait()
        torch.cuda.synchronize(device=device)

        # de-register buffers from NCCL
        backend.deregister_mem_pool(pool)

        # clean up memory
        del tensor, pool

        with open(os.environ["NCCL_DEBUG_FILE"]) as f:
            nccl_debug_file_content = f.read()
            # if buffers were registered and the NVLS reduction ran, the NCCL
            # debug log should contain "local-registered"
            self.assertRegex(nccl_debug_file_content, "local-registered")


class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
    @property
    def device(self):

View file

@@ -23,6 +23,7 @@ from torch.distributed._symmetric_memory import (
from torch.testing._internal.common_cuda import _get_torch_cuda_version, SM90OrLater
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    requires_multicast_support,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import (
@@ -57,17 +58,6 @@ def requires_cuda_p2p_access():
    )


def requires_multicast_support():
    has_multicast_support = (
        torch.cuda.is_available()
        and _SymmetricMemory.has_multicast_support(DeviceType.CUDA, 0)
    )
    return skip_but_pass_in_sandcastle_if(
        not has_multicast_support,
        "multicast support is not available",
    )


@instantiate_parametrized_tests
@requires_cuda_p2p_access()
class SymmetricMemoryTest(MultiProcessTestCase):

View file

@@ -578,6 +578,8 @@ class ProcessGroupNCCL(Backend):
    def _set_default_timeout(self, timeout) -> None: ...
    def _shutdown(self) -> None: ...
    def perform_nocolor_split(self, device: torch.device) -> None: ...
    def register_mem_pool(self, pool: torch.cuda.MemPool) -> None: ...
    def deregister_mem_pool(self, pool: torch.cuda.MemPool) -> None: ...
    def comm_split_count(self) -> int: ...
    def _add_ephemeral_timeout(self, timeout: timedelta) -> None: ...
    def abort(self) -> None: ...

View file

@@ -1113,6 +1113,59 @@ bool ProcessGroupNCCL::isInitialized() {
  return initialized;
}

void ProcessGroupNCCL::registerMemPool(c10::cuda::MemPool* pool) {
  const auto key = std::to_string(pool->device());
  auto device = at::Device(at::DeviceType::CUDA, pool->device());
  LOG(INFO) << logPrefix()
            << "Performing NCCL user buffer registration for all buffers in "
            << "MemPool: " << pool->id() << ", device index: " << key
            << ", i am " << this;
  auto ncclComm = getNCCLComm(key);
  if (ncclComm == nullptr) {
    // HACK: currently we are using this function for NVLS
    // reductions, and that's why using OpType::ALLREDUCE.
    // If we end up using this API for zero-copy P2P, we might
    // need to refactor and account for different OpType.
    ncclComm = initNCCLComm(key, device, OpType::ALLREDUCE);
  }
  TORCH_INTERNAL_ASSERT(ncclComm != nullptr);
  auto ctx = c10::cuda::MemPoolContext(pool);
  auto snapshot = c10::cuda::CUDACachingAllocator::snapshot();
  for (const auto& segmentInfo : snapshot.segments) {
    TORCH_INTERNAL_ASSERT(
        segmentInfo.device == pool->device(),
        "Mismatch between CUDA memory segment device and pool's device");
    ncclComm->registerSegment(
        reinterpret_cast<void*>(segmentInfo.address), segmentInfo.total_size);
  }
}

void ProcessGroupNCCL::deregisterMemPool(c10::cuda::MemPool* pool) {
  const auto key = std::to_string(pool->device());
  auto device = at::Device(at::DeviceType::CUDA, pool->device());
  LOG(INFO) << logPrefix()
            << "Performing NCCL user buffer deregistration for all buffers in "
            << "MemPool: " << pool->id() << ", device index: " << key
            << ", i am " << this;
  auto ncclComm = getNCCLComm(key);
  if (ncclComm == nullptr) {
    // HACK: currently we are using this function for NVLS
    // reductions, and that's why using OpType::ALLREDUCE.
    // If we end up using this API for zero-copy P2P, we might
    // need to refactor and account for different OpType.
    ncclComm = initNCCLComm(key, device, OpType::ALLREDUCE);
  }
  TORCH_INTERNAL_ASSERT(ncclComm != nullptr);
  auto ctx = c10::cuda::MemPoolContext(pool);
  auto snapshot = c10::cuda::CUDACachingAllocator::snapshot();
  for (const auto& segmentInfo : snapshot.segments) {
    TORCH_INTERNAL_ASSERT(
        segmentInfo.device == pool->device(),
        "Mismatch between CUDA memory segment device and pool's device");
    ncclComm->deregisterSegment(reinterpret_cast<void*>(segmentInfo.address));
  }
}

c10::intrusive_ptr<intra_node_comm::IntraNodeComm> ProcessGroupNCCL::
    initIntraNodeComm() {
  using IntraNodeComm = intra_node_comm::IntraNodeComm;

View file

@@ -757,6 +757,14 @@ class TORCH_API ProcessGroupNCCL : public Backend {
  // If all comms on this PG are fully initialized, return true.
  bool isInitialized();

  // Performs NCCL user buffer registration for all buffers in
  // the given MemPool
  void registerMemPool(c10::cuda::MemPool* pool);

  // Performs NCCL user buffer de-registration for all buffers in
  // the given MemPool
  void deregisterMemPool(c10::cuda::MemPool* pool);

  // This method adds a temporary extension for the timeout period,
  // applying to all collectives between the calling of this API and
  // the completion of the first collective on the GPU. While this feature

View file

@@ -2958,6 +2958,10 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
          .def(
              "perform_nocolor_split",
              &::c10d::ProcessGroupNCCL::performNocolorSplit)
          .def("register_mem_pool", &::c10d::ProcessGroupNCCL::registerMemPool)
          .def(
              "deregister_mem_pool",
              &::c10d::ProcessGroupNCCL::deregisterMemPool)
          .def(
              "abort",
              &::c10d::ProcessGroupNCCL::abort,

View file

@@ -29,6 +29,8 @@ import torch
import torch._dynamo.test_case
import torch.cuda.nccl
import torch.distributed as c10d
from torch._C._autograd import DeviceType
from torch._C._distributed_c10d import _SymmetricMemory
import torch.nn as nn
from torch.testing._internal.common_utils import (
    FILE_SCHEMA,
@@ -348,6 +350,17 @@ def requires_mpi():
    )


def requires_multicast_support():
    has_multicast_support = (
        torch.cuda.is_available()
        and _SymmetricMemory.has_multicast_support(DeviceType.CUDA, 0)
    )
    return skip_but_pass_in_sandcastle_if(
        not has_multicast_support,
        "multicast support is not available",
    )


def skip_if_rocm_multiprocess(func):
    """Skips a test for ROCm"""
    func.skip_if_rocm_multiprocess = True