mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Implements torch.cuda.MemPool() API (#131152)
In this PR: - Pool id creation logic is refactored and moved to a MemPool class. `graph_pool_handle()` API now uses `torch.cuda.MemPool()` to get a unique id for a pool. Existing tests should cover this change. - MemPool holds a pointer to a CUDAAllocator as proposed in https://github.com/pytorch/pytorch/issues/124807#issuecomment-2077506997. Tests are added to show usage with CUDAPluggableAllocator. - MemPoolContext API makes a mempool active. Tests are added to show usage of this API. This API will be used in CUDACachingAllocator to route allocations to a user provided allocator. See draft here: https://github.com/pytorch/pytorch/pull/125722/ Pull Request resolved: https://github.com/pytorch/pytorch/pull/131152 Approved by: https://github.com/eqy, https://github.com/ezyang
This commit is contained in:
parent
4e966e8a1c
commit
7c89ec0f7c
13 changed files with 303 additions and 24 deletions
|
|
@ -17,23 +17,10 @@ static bool _cuda_graphs_debug = false;
|
|||
constexpr int kSynchronizeBusyWaitMillis = 10;
|
||||
|
||||
MempoolId_t graph_pool_handle() {
|
||||
// uuid count starts at 1. 0 is reserved to mean "wasn't set by graph_pool_handle".
|
||||
static std::atomic<CaptureId_t> uid{1};
|
||||
// Sets just the second value, to distinguish it from MempoolId_ts created from
|
||||
// cudaStreamGetCaptureInfo id_s in capture_begin.
|
||||
return {0, uid++};
|
||||
}
|
||||
|
||||
|
||||
// Get the expected id of a capture sequence so that we can call beginAllocateStreamToPool
|
||||
// before starting a graph capture
|
||||
CaptureId_t capture_sequence_id() {
|
||||
// id starts at 1:
|
||||
// Ensures uuid count starts at 1. 0 is reserved to mean "not set by cudaStreamGetCaptureInfo".
|
||||
// (But how do we know GetCaptureInfo never sets id_ to 0? Because that's the current behavior,
|
||||
// and I asked cuda devs to keep it that way, and they agreed.)
|
||||
static std::atomic<CaptureId_t> uuid{1};
|
||||
return uuid++;
|
||||
auto new_pool = c10::cuda::MemPool();
|
||||
return new_pool.id();
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -118,8 +105,6 @@ void CUDAGraph::capture_begin(MempoolId_t pool/*=0*/, cudaStreamCaptureMode capt
|
|||
capture_stream_ = stream;
|
||||
capture_dev_ = c10::cuda::current_device();
|
||||
|
||||
id_ = capture_sequence_id();
|
||||
|
||||
if (pool.first != 0 || pool.second != 0) {
|
||||
// Either value being nonzero means the user supplied a pool to share.
|
||||
// But only one should be nonzero.
|
||||
|
|
@ -128,9 +113,11 @@ void CUDAGraph::capture_begin(MempoolId_t pool/*=0*/, cudaStreamCaptureMode capt
|
|||
TORCH_INTERNAL_ASSERT(!(pool.first && pool.second));
|
||||
mempool_id_ = pool;
|
||||
} else {
|
||||
// User did not ask us to share a mempool. Use our own id_ as our mempool_id_.
|
||||
// User did not ask us to share a mempool. Create graph pool handle using is_user_created=false.
|
||||
// Sets just the first value, to distinguish it from MempoolId_ts created by graph_pool_handle().
|
||||
mempool_id_ = {id_, 0};
|
||||
auto mempool = c10::cuda::MemPool({}, false);
|
||||
mempool_id_ = mempool.id();
|
||||
TORCH_INTERNAL_ASSERT(mempool_id_.first > 0);
|
||||
}
|
||||
|
||||
// Addendum: beginAllocateStreamToPool is now called before cudaStreamBeginCapture to prevent an
|
||||
|
|
@ -161,7 +148,6 @@ void CUDAGraph::capture_begin(MempoolId_t pool/*=0*/, cudaStreamCaptureMode capt
|
|||
AT_CUDA_CHECK(cudaStreamGetCaptureInfo(stream, &status, &capture_id_));
|
||||
TORCH_INTERNAL_ASSERT(status == cudaStreamCaptureStatus::cudaStreamCaptureStatusActive);
|
||||
|
||||
TORCH_INTERNAL_ASSERT(id_ > 0);
|
||||
}
|
||||
|
||||
void CUDAGraph::capture_end() {
|
||||
|
|
|
|||
|
|
@ -52,10 +52,6 @@ struct TORCH_CUDA_CPP_API CUDAGraph {
|
|||
// Set to true in capture_end if cudaGraphInstantiate succeeded
|
||||
bool has_graph_exec_ = false;
|
||||
|
||||
// uuid of this instance's current capture, used to
|
||||
// specify the pool.
|
||||
CaptureId_t id_;
|
||||
|
||||
// the ID assigned by cuda during graph capture,
|
||||
// used to identify when a stream is participating in capture
|
||||
CaptureId_t capture_id_ = -1;
|
||||
|
|
|
|||
|
|
@ -770,6 +770,7 @@ libtorch_python_cuda_core_sources = [
|
|||
"torch/csrc/cuda/python_comm.cpp",
|
||||
"torch/csrc/cuda/Stream.cpp",
|
||||
"torch/csrc/cuda/Graph.cpp",
|
||||
"torch/csrc/cuda/MemPool.cpp",
|
||||
"torch/csrc/cuda/shared/cudart.cpp",
|
||||
"torch/csrc/cuda/shared/nvtx.cpp",
|
||||
"torch/csrc/cuda/utils.cpp",
|
||||
|
|
|
|||
|
|
@ -3596,3 +3596,58 @@ BackendStaticInitializer backend_static_initializer;
|
|||
} // namespace cuda::CUDACachingAllocator
|
||||
|
||||
} // namespace c10
|
||||
|
||||
namespace c10::cuda {
|
||||
|
||||
// uid_ is incremented when a user creates a MemPool,
|
||||
// for example: using graph_pool_handle() or c10::cuda::MemPool().
|
||||
//
|
||||
// uuid_ is incremented when CUDAGraph creates a MemPool
|
||||
// as a result of a user not providing a pool.
|
||||
//
|
||||
// MempoolId_t of {0, 0} is used to denote when no MemPool has been
|
||||
// passed to a function, either by user or CUDAGraphs. For example,
|
||||
// default value of MempoolId_t for capture_begin function is {0, 0}.
|
||||
// That's why uid_ and uuid_ start at 1.
|
||||
std::atomic<CaptureId_t> MemPool::uid_{1};
|
||||
std::atomic<CaptureId_t> MemPool::uuid_{1};
|
||||
|
||||
MemPool::MemPool(
|
||||
CUDACachingAllocator::CUDAAllocator* allocator,
|
||||
bool is_user_created)
|
||||
: allocator_(allocator), is_user_created_(is_user_created) {
|
||||
if (is_user_created_) {
|
||||
id_ = {0, uid_++};
|
||||
} else {
|
||||
id_ = {uuid_++, 0};
|
||||
}
|
||||
}
|
||||
|
||||
MempoolId_t MemPool::id() {
|
||||
return id_;
|
||||
}
|
||||
|
||||
CUDACachingAllocator::CUDAAllocator* MemPool::allocator() {
|
||||
return allocator_;
|
||||
}
|
||||
|
||||
// Note that active_mempool_ is a global variable here
|
||||
// and not inside MemPoolContext class, because in windows we
|
||||
// can't use __declspec(dllexport) and __declspec(thread)
|
||||
// together: https://stackoverflow.com/a/50967977
|
||||
static thread_local MemPool* active_mempool_ = nullptr;
|
||||
|
||||
MemPoolContext::MemPoolContext(MemPool* mempool)
|
||||
: prev_mempool_(active_mempool_) {
|
||||
active_mempool_ = mempool;
|
||||
}
|
||||
|
||||
MemPoolContext::~MemPoolContext() {
|
||||
active_mempool_ = prev_mempool_;
|
||||
}
|
||||
|
||||
MemPool* MemPoolContext::getActiveMemPool() {
|
||||
return active_mempool_;
|
||||
}
|
||||
|
||||
} // namespace c10::cuda
|
||||
|
|
|
|||
|
|
@ -513,3 +513,51 @@ inline void enablePeerAccess(
|
|||
}
|
||||
|
||||
} // namespace c10::cuda::CUDACachingAllocator
|
||||
|
||||
namespace c10::cuda {
|
||||
|
||||
// MemPool represents a pool of memory in a caching allocator. Currently,
|
||||
// it's just the ID of the pool object maintained in the CUDACachingAllocator.
|
||||
//
|
||||
// An allocator pointer can be passed to the MemPool to define how the
|
||||
// allocations should be done in the pool. For example: using a different
|
||||
// system allocator such as ncclMemAlloc.
|
||||
struct C10_CUDA_API MemPool {
|
||||
MemPool(
|
||||
CUDACachingAllocator::CUDAAllocator* allocator = nullptr,
|
||||
bool is_user_created = true);
|
||||
|
||||
MempoolId_t id();
|
||||
CUDACachingAllocator::CUDAAllocator* allocator();
|
||||
|
||||
private:
|
||||
static std::atomic<CaptureId_t> uid_;
|
||||
static std::atomic<CaptureId_t> uuid_;
|
||||
CUDACachingAllocator::CUDAAllocator* allocator_;
|
||||
bool is_user_created_;
|
||||
MempoolId_t id_;
|
||||
};
|
||||
|
||||
// MemPoolContext holds the currently active pool and stashes the previous
|
||||
// pool. On deletion it makes the previous pool active.
|
||||
struct C10_CUDA_API MemPoolContext {
|
||||
MemPoolContext(MemPool* mempool);
|
||||
|
||||
~MemPoolContext();
|
||||
|
||||
// getActiveMemPool() can be used to get the currently active pool.
|
||||
// For instance: in CUDACachingAllocator, we can route allocations
|
||||
// to a user provided allocator, by doing:
|
||||
//
|
||||
// auto active_pool = MemPoolContext::getActiveMemPool();
|
||||
// if (active_pool && active_pool->allocator()) {
|
||||
// ptr = active_pool->allocator()->raw_alloc(size);
|
||||
// }
|
||||
//
|
||||
static MemPool* getActiveMemPool();
|
||||
|
||||
private:
|
||||
MemPool* prev_mempool_;
|
||||
};
|
||||
|
||||
} // namespace c10::cuda
|
||||
|
|
|
|||
|
|
@ -2166,6 +2166,8 @@ coverage_ignore_classes = [
|
|||
"EventHandler",
|
||||
"SynchronizationError",
|
||||
"UnsynchronizedAccessError",
|
||||
# torch.cuda.memory
|
||||
"MemPoolContext",
|
||||
# torch.distributed.elastic.multiprocessing.errors
|
||||
"ChildFailedError",
|
||||
"ProcessFailure",
|
||||
|
|
|
|||
|
|
@ -120,6 +120,8 @@ Memory management
|
|||
get_allocator_backend
|
||||
CUDAPluggableAllocator
|
||||
change_current_allocator
|
||||
MemPool
|
||||
MemPoolContext
|
||||
.. FIXME The following doesn't seem to exist. Is it supposed to?
|
||||
https://github.com/pytorch/pytorch/issues/27785
|
||||
.. autofunction:: reset_max_memory_reserved
|
||||
|
|
|
|||
|
|
@ -4996,6 +4996,101 @@ class TestBlockStateAbsorption(TestCase):
|
|||
self.assertEqual(rc, "False", "Triton was imported when importing torch!")
|
||||
|
||||
|
||||
class TestMemPool(TestCase):
    """Tests for torch.cuda.MemPool and torch.cuda.MemPoolContext."""

    def test_mempool_id(self):
        pool1 = torch.cuda.graph_pool_handle()
        pool2 = torch.cuda.MemPool().id

        # first value of id in a user created pool is always zero; check each
        # id on its own so the test cannot pass when both are non-zero
        self.assertEqual(pool1[0], 0)
        self.assertEqual(pool2[0], 0)

        # each call to torch.cuda.graph_pool_handle() or torch.cuda.MemPool()
        # increments the id
        self.assertGreater(pool2[1], pool1[1])

    def test_mempool_with_allocator(self):
        pool = torch.cuda.MemPool()

        # MemPool doesn't have an allocator by default
        self.assertIsNone(pool.allocator)

        from torch.utils.cpp_extension import load_inline

        dummy_allocator_source = """
        extern "C" {
          void* dummy_alloc(size_t size, int device, void* stream) { return nullptr; }
          void dummy_free(void* ptr) { }
        }
        """
        dummy_allocator_libname = "dummy_allocator"
        with tempfile.TemporaryDirectory() as tempdir:
            # Build the shared library only; the return value is not needed.
            load_inline(
                name=dummy_allocator_libname,
                cpp_sources=dummy_allocator_source,
                is_python_module=False,
                build_directory=tempdir,
            )
            allocator = torch.cuda.memory.CUDAPluggableAllocator(
                os.path.join(tempdir, f"{dummy_allocator_libname}.so"),
                "dummy_alloc",
                "dummy_free",
            )
            pool = torch.cuda.MemPool(allocator.allocator())

            # pool should point to the same allocator as the one passed into it
            self.assertEqual(allocator.allocator(), pool.allocator)

    def test_mempool_context(self):
        active_pool = torch.cuda.MemPoolContext.active_pool()

        # there is no active pool if none was made active
        self.assertIsNone(active_pool)

        pool = torch.cuda.MemPool()
        ctx = torch.cuda.MemPoolContext(pool)
        active_pool = torch.cuda.MemPoolContext.active_pool()

        # pool was made active
        self.assertEqual(active_pool, pool)

        del ctx
        active_pool = torch.cuda.MemPoolContext.active_pool()

        # ctx was deleted, so active pool is the previous one
        self.assertIsNone(active_pool)

    def test_mempool_multithread(self):
        pool_ids = []
        active_pool_ids = []

        def create_mempool_and_make_active():
            pool = torch.cuda.MemPool()
            pool_ids.append(pool.id)

            ctx = torch.cuda.MemPoolContext(pool)
            active_pool = torch.cuda.MemPoolContext.active_pool()
            active_pool_ids.append(active_pool.id)
            del ctx

        num_threads = 4
        threads = [
            threading.Thread(target=create_mempool_and_make_active)
            for _ in range(num_threads)
        ]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

        # each thread should create a unique mempool, since
        # mempool id creation is atomic
        self.assertEqual(len(set(pool_ids)), num_threads)

        # each thread should have different active mempool, since
        # the pointer to the mempool is thread local
        self.assertEqual(len(set(active_pool_ids)), num_threads)
|
||||
|
||||
|
||||
class TestCudaOptims(TestCase):
|
||||
# These tests will be instantiated with instantiate_device_type_tests
|
||||
# to apply the new OptimizerInfo structure.
|
||||
|
|
|
|||
|
|
@ -2059,6 +2059,19 @@ class _CUDAGraph:
|
|||
def enable_debug_mode(self) -> None: ...
|
||||
def debug_dump(self, debug_path: str) -> None: ...
|
||||
|
||||
# Defined in torch/csrc/cuda/MemPool.cpp
|
||||
class _MemPool:
    # Typing stub for the `_MemPool` class registered on torch._C.
    def __init__(
        self,
        allocator: Optional[_cuda_CUDAAllocator] = None,
        is_user_created: _bool = True,
    ) -> None: ...

    # Pool id as a pair of ints.
    @property
    def id(self) -> Tuple[_int, _int]: ...

    # Allocator attached to the pool, if any.
    @property
    def allocator(self) -> Optional[_cuda_CUDAAllocator]: ...
|
||||
|
||||
class _MemPoolContext:
    # Typing stub for the `_MemPoolContext` class registered on torch._C.
    def __init__(
        self,
        pool: _MemPool,
    ) -> None: ...

    # Currently active pool, or None when no pool is active.
    @staticmethod
    def active_pool() -> Optional[_MemPool]: ...
|
||||
|
||||
def _cuda_isCurrentStreamCapturing() -> _bool: ...
|
||||
def _graph_pool_handle() -> Tuple[_int, _int]: ...
|
||||
|
||||
|
|
|
|||
|
|
@ -1551,6 +1551,7 @@ static PyMethodDef TorchMethods[] = { // NOLINT
|
|||
void THCPStream_init(PyObject* module);
|
||||
void THCPEvent_init(PyObject* module);
|
||||
void THCPGraph_init(PyObject* module);
|
||||
void THCPMemPool_init(PyObject* module);
|
||||
|
||||
#ifdef USE_CUDA
|
||||
PyMethodDef* THCPModule_methods();
|
||||
|
|
@ -1708,6 +1709,7 @@ PyObject* initModule() {
|
|||
THCPStream_init(module);
|
||||
THCPEvent_init(module);
|
||||
THCPGraph_init(module);
|
||||
THCPMemPool_init(module);
|
||||
#endif
|
||||
|
||||
#ifdef USE_XPU
|
||||
|
|
|
|||
21
torch/csrc/cuda/MemPool.cpp
Normal file
21
torch/csrc/cuda/MemPool.cpp
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
#include <torch/csrc/python_headers.h>
|
||||
|
||||
#include <torch/csrc/jit/python/pybind_utils.h>
|
||||
#include <torch/csrc/utils/pybind.h>
|
||||
|
||||
#include <c10/cuda/CUDACachingAllocator.h>
|
||||
|
||||
template <typename T>
|
||||
using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;
|
||||
|
||||
// Registers the `_MemPool` and `_MemPoolContext` classes on the given module.
void THCPMemPool_init(PyObject* module) {
  auto torch_C_m = py::handle(module).cast<py::module>();

  shared_ptr_class_<::c10::cuda::MemPool>(torch_C_m, "_MemPool")
      // Defaults mirror the C++ constructor, so `_MemPool()` and keyword
      // calls behave the way the Python type stub advertises.
      .def(
          py::init<c10::cuda::CUDACachingAllocator::CUDAAllocator*, bool>(),
          py::arg("allocator") = py::none(),
          py::arg("is_user_created") = true)
      .def_property_readonly("id", &::c10::cuda::MemPool::id)
      .def_property_readonly("allocator", &::c10::cuda::MemPool::allocator);

  shared_ptr_class_<::c10::cuda::MemPoolContext>(torch_C_m, "_MemPoolContext")
      // The context only stores a raw MemPool*. keep_alive<1, 2> ties the
      // lifetime of the pool argument (2) to the context object (1) so the
      // active-pool pointer cannot dangle while the context exists.
      .def(
          py::init<c10::cuda::MemPool*>(),
          py::arg("pool"),
          py::keep_alive<1, 2>())
      // The returned pointer is borrowed: Python must never take ownership
      // of the active pool.
      .def_static(
          "active_pool",
          &::c10::cuda::MemPoolContext::getActiveMemPool,
          py::return_value_policy::reference);
}
|
||||
|
|
@ -1621,6 +1621,8 @@ __all__ = [
|
|||
"memory_stats_as_nested_dict",
|
||||
"memory_summary",
|
||||
"memory_usage",
|
||||
"MemPool",
|
||||
"MemPoolContext",
|
||||
"temperature",
|
||||
"power_draw",
|
||||
"clock_rate",
|
||||
|
|
|
|||
|
|
@ -51,6 +51,8 @@ __all__ = [
|
|||
"get_allocator_backend",
|
||||
"CUDAPluggableAllocator",
|
||||
"change_current_allocator",
|
||||
"MemPool",
|
||||
"MemPoolContext",
|
||||
]
|
||||
|
||||
|
||||
|
|
@ -59,6 +61,14 @@ if not hasattr(torch._C, "_cuda_CUDAAllocator"):
|
|||
torch._C.__dict__["_cuda_CUDAAllocator"] = _dummy_type("_cuda_CUDAAllocator")
|
||||
|
||||
|
||||
if not hasattr(torch._C, "_MemPool"):
    # torch._C has no _MemPool attribute: install placeholder base classes so
    # the import and the class definitions below still work.
    torch._C.__dict__.update(
        _MemPool=_dummy_type("_MemPool"),
        _MemPoolContext=_dummy_type("_MemPoolContext"),
    )
|
||||
|
||||
from torch._C import _cuda_CUDAAllocator, _MemPool, _MemPoolContext # noqa: F401
|
||||
|
||||
|
||||
def _host_allocator():
|
||||
_lazy_init()
|
||||
return torch._C._cuda_cudaHostAllocator()
|
||||
|
|
@ -946,3 +956,49 @@ def _get_current_allocator() -> _CUDAAllocator:
|
|||
See :ref:`cuda-memory-management` for details on creating and using a custom allocator
|
||||
"""
|
||||
return _CUDAAllocator(torch._C._cuda_getAllocator())
|
||||
|
||||
|
||||
class MemPool(_MemPool):
    r"""A pool of memory in a caching allocator.

    At present a MemPool is just the ID of the pool object maintained in the
    CUDACachingAllocator.

    Args:
        allocator(torch._C._cuda_CUDAAllocator, optional): a
            torch._C._cuda_CUDAAllocator object that can be used to
            define how memory gets allocated in the pool. If :attr:`allocator`
            is ``None`` (default), memory allocation follows the default/
            current configuration of the CUDACachingAllocator.

    """

    def __init__(self, allocator: Optional[_cuda_CUDAAllocator] = None):
        # Pools built through this public class are always user-created.
        is_user_created = True
        super().__init__(allocator, is_user_created)

    @property
    def id(self) -> Tuple[int, int]:
        r"""Returns the ID of this pool as a tuple of two ints."""
        pool_id = super().id
        return pool_id

    @property
    def allocator(self) -> Optional[_cuda_CUDAAllocator]:
        r"""Returns the allocator this MemPool routes allocations to."""
        pool_allocator = super().allocator
        return pool_allocator
|
||||
|
||||
|
||||
class MemPoolContext(_MemPoolContext):
    r"""Makes a pool active and remembers the pool that was active before it.

    When the context object is deleted, the previously active pool becomes
    active again.

    Args:
        pool(torch.cuda.MemPool): a MemPool object to be made active so that
            allocations route to this pool.

    """

    def __init__(self, pool: MemPool):
        _MemPoolContext.__init__(self, pool)

    @staticmethod
    def active_pool() -> Optional[_MemPool]:
        r"""Returns the active MemPool, or ``None`` if no pool is active."""
        current = _MemPoolContext.active_pool()
        return current
|
||||
|
|
|
|||
Loading…
Reference in a new issue