Improve detection of workspace/non-output allocations in cudagraphs (#99985)

When we run cudagraph trees we are not allowed to have permanent workspace allocations like the one cublas keeps: we might need to reclaim that memory for a previous cudagraph recording, and it is memory that is not accounted for in the output weakrefs, so it does not work with checkpointing. Previously I checked that we didn't have any additional allocations by taking a memory snapshot, but that was extremely slow, so I had to turn it off.
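
For context, the way cudagraph trees already handles the cublas case is to clear the workspace around graph capture so that no permanent workspace lands in the graph's private pool. A minimal sketch of that pattern, paraphrasing the clear_cublas_manager context manager in torch._inductor.cudagraph_trees (the function name below is illustrative):

    import contextlib
    import torch

    @contextlib.contextmanager
    def cublas_workspace_cleared():
        # Drop any existing cublas workspace before capture so it cannot be
        # recorded into the graph's private memory pool, and drop it again on
        # exit so no permanent workspace outlives the capture.
        torch._C._cuda_clearCublasWorkspaces()
        try:
            yield
        finally:
            torch._C._cuda_clearCublasWorkspaces()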

This PR first does a quick check to see whether we are in an error state, and only if we are does it do the slow work of creating a snapshot. It also turns on history recording so that we get a stack trace of where the bad allocation came from.
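
In outline, the new check_memory_pool flow looks roughly like the sketch below (simplified, not the exact implementation; check_pool and expected_live_data_ptrs are illustrative names, while torch._C._cuda_checkPoolLiveAllocations and get_cudagraph_segments are the binding and helper the new code uses):

    import torch
    from torch._inductor.cudagraph_trees import get_cudagraph_segments

    def check_pool(device, pool_id, expected_live_data_ptrs):
        # Fast path: ask the caching allocator whether the live blocks in the
        # private pool are exactly the expected data pointers. No snapshot is taken.
        if torch._C._cuda_checkPoolLiveAllocations(device, pool_id, expected_live_data_ptrs):
            return

        # Slow path: we already know something diverged, so the expensive
        # snapshot is only taken to build a readable error message.
        unaccounted = []
        for segment in get_cudagraph_segments(pool_id):
            addr = segment["address"]
            for block in segment["blocks"]:
                if block["state"] == "active_allocated" and addr not in expected_live_data_ptrs:
                    # With history recording enabled, block["history"] carries the
                    # allocation stack trace that ends up in the error message.
                    unaccounted.append((addr, block.get("history")))
                addr += block["size"]
        if unaccounted:
            raise RuntimeError(f"Allocations in cudagraph pool not accounted for as outputs: {unaccounted}")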

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99985
Approved by: https://github.com/zdevito
Author: Elias Ellison (committed by PyTorch MergeBot)
Date: 2023-04-29 00:07:56 +00:00
Parent: 5d93265cce
Commit: 3edff6b6ec
7 changed files with 224 additions and 19 deletions


@@ -1271,6 +1271,35 @@ class DeviceCachingAllocator {
alloc_trace->clear();
}
bool isHistoryEnabled() {
return record_history;
}
bool checkPoolLiveAllocations(
MempoolId_t mempool_id,
const std::unordered_set<void*>& expected_live_allocations) {
std::unique_lock<std::recursive_mutex> lock(mutex);
PrivatePool* pool;
auto pool_it = graph_pools.find(mempool_id);
TORCH_CHECK(pool_it != graph_pools.end(), "Could not find pool of id");
pool = pool_it->second.get();
size_t allocated_pool_blocks = 0;
for (Block* b : active_blocks) {
if (b->allocated && b->pool->owner_PrivatePool == pool) {
if (!expected_live_allocations.count(b->ptr)) {
return false;
}
allocated_pool_blocks += 1;
}
}
return allocated_pool_blocks == expected_live_allocations.size();
}
void attachOutOfMemoryObserver(OutOfMemoryObserver observer) {
oom_observers_.emplace_back(std::move(observer));
}
@@ -3176,6 +3205,20 @@ class NativeCachingAllocator : public CUDAAllocator {
alloc_trace_record_context);
}
bool isHistoryEnabled() override {
int device;
C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
return device_allocator[device]->isHistoryEnabled();
}
bool checkPoolLiveAllocations(
int device,
MempoolId_t mempool_id,
const std::unordered_set<void*>& expected_live_allocations) override {
return device_allocator[device]->checkPoolLiveAllocations(
mempool_id, expected_live_allocations);
}
void attachOutOfMemoryObserver(OutOfMemoryObserver observer) override {
int device;
C10_CUDA_CHECK(c10::cuda::GetDevice(&device));


@@ -10,6 +10,7 @@
#include <array>
#include <mutex>
#include <set>
#include <unordered_set>
namespace c10 {
@@ -217,7 +218,25 @@ class CUDAAllocator : public Allocator {
MempoolId_t mempool_id) = 0;
virtual void endAllocateStreamToPool(int device, cudaStream_t stream) = 0;
virtual void releasePool(int device, MempoolId_t mempool_id) = 0;
// returns true if the allocated blocks are equal to expected live allocations
virtual bool checkPoolLiveAllocations(
int device,
MempoolId_t mempool_id,
const std::unordered_set<void*>& expected_live_allocations) {
TORCH_CHECK(
false,
name(),
" does not yet support checkPoolLiveAllocations. "
"If you need it, please file an issue describing your use case.");
}
virtual std::shared_ptr<void> getIpcDevPtr(std::string handle) = 0;
virtual bool isHistoryEnabled() {
TORCH_CHECK(
false,
name(),
" does not yet support recordHistory. "
"If you need it, please file an issue describing your use case.");
}
virtual void recordHistory(
bool enabled,
CreateContextFn context_recorder,
@@ -355,6 +374,18 @@ inline void recordHistory(
alloc_trace_record_context);
}
inline bool isHistoryEnabled() {
return get()->isHistoryEnabled();
}
inline bool checkPoolLiveAllocations(
int device,
MempoolId_t mempool_id,
const std::unordered_set<void*>& expected_live_allocations) {
return get()->checkPoolLiveAllocations(
device, mempool_id, expected_live_allocations);
}
inline void attachOutOfMemoryObserver(OutOfMemoryObserver observer) {
return get()->attachOutOfMemoryObserver(observer);
}


@@ -13,9 +13,11 @@ import torch._dynamo
import torch.nn as nn
from torch._inductor import config
from torch._inductor.cudagraph_trees import cudagraphify_impl as tree_cudagraphify_impl
from torch.testing import FileCheck
from torch.testing._internal.common_utils import (
IS_CI,
IS_LINUX,
IS_WINDOWS,
TEST_WITH_ASAN,
TEST_WITH_ROCM,
@@ -727,6 +729,42 @@ if HAS_CUDA and not TEST_WITH_ASAN:
del x
self.assertEqual(all_live_block_count(), 0)
@unittest.skipIf(not IS_LINUX, "cpp contexts are linux only")
@torch._inductor.config.patch("triton.cudagraph_trees_history_recording", True)
def test_workspace_allocation_error(self):
torch._C._cuda_clearCublasWorkspaces()
prev = torch._inductor.cudagraph_trees.clear_cublas_manager
try:
torch._inductor.cudagraph_trees.clear_cublas_manager = (
contextlib.nullcontext
)
@torch.compile()
def foo(x, y):
return x @ x
inps = [torch.rand([400, 400], device="cuda") for _ in range(2)]
thrown = False
try:
foo(*inps)
except Exception as e:
thrown = True
FileCheck().check("at::cuda::getNewWorkspace").check(
"at::cuda::blas::gemm<float>"
).run(str(e))
self.assertTrue(thrown)
finally:
torch._C._cuda_clearCublasWorkspaces()
torch._inductor.cudagraph_trees.clear_cublas_manager = prev
torch._inductor.cudagraph_trees.get_container(
self.device_idx
).tree_manager = None
def test_peristed_output_livenes(self):
@torch.compile
def foo(x):


@@ -5631,6 +5631,27 @@ class TestBlockStateAbsorption(TestCase):
inp = torch.rand([1, 3, 255, 255], device="cuda")
self.checkFunction(m, [inp])
def test_check_pool_live_allocations(self):
def foo():
return torch.ones([4], device="cuda")
pool = torch.cuda.graph_pool_handle()
graph, outputs = cudagraphify(foo, [], pool=pool)
index = outputs[0].device.index
def check(live_dps):
return torch._C._cuda_checkPoolLiveAllocations(index, pool, live_dps)
self.assertTrue(check({outputs[0].data_ptr()}))
self.assertFalse(check({outputs[0].data_ptr(), 0}))
self.assertFalse(check(set()))
del outputs
self.assertTrue(check(set()))
def test_allocate_in_thread_to_pool(self):


@@ -220,7 +220,10 @@ class triton:
cudagraph_trees = not is_fbcode()
# assertions not on the fast path, steady state
slow_path_cudagraph_asserts = False
slow_path_cudagraph_asserts = True
# TODO - need to debug why this prevents cleanup
cudagraph_trees_history_recording = False
# assertions on the fast path
fast_path_cudagraph_asserts = False


@@ -39,10 +39,10 @@ from __future__ import annotations
import contextlib
import dataclasses
import functools
import gc
import itertools
import sys
import threading
import traceback
import warnings
import weakref
from collections import defaultdict
@@ -146,6 +146,26 @@ def clear_cublas_manager():
clear_cublass_cache()
@contextlib.contextmanager
def enable_history_recording():
"Turns on history recording in the CUDA Caching Allocator"
enabled = torch._C._cuda_isHistoryEnabled()
try:
if not enabled:
torch.cuda.memory._record_memory_history()
yield
finally:
if not enabled:
torch.cuda.memory._record_memory_history(None)
def get_history_recording():
# TODO - remove, prevents cleanup
if not config.triton.cudagraph_trees_history_recording:
return contextlib.nullcontext()
return enable_history_recording()
class TreeManagerContainer:
"""
Manages the lifetime of the tree manager. Like `PrivatePool` in cuda caching allocator,
@@ -516,13 +536,13 @@ class CUDAWarmupNode:
if config.triton.slow_path_cudagraph_asserts and not self.already_warm:
refs = list(self.path_live_weakrefs())
check_memory_pool(self.cuda_graphs_pool, refs)
check_memory_pool(self.device_index, self.cuda_graphs_pool, refs)
with torch.cuda.device(
self.device_index
), clear_cublas_manager(), _use_cuda_memory_pool_manager(
self.device_index, self.cuda_graphs_pool, self.stream
):
), get_history_recording():
out = self.wrapped_function.model(new_inputs)
# sync up stream used in `_use_cuda_memory_pool_manager` - TODO - wait stream instead ?
@@ -544,7 +564,7 @@ class CUDAWarmupNode:
new_storages = [
t for t in out_refs if t.data_ptr() not in non_cudagraph_inps
]
check_memory_pool(self.cuda_graphs_pool, new_storages)
check_memory_pool(self.device_index, self.cuda_graphs_pool, new_storages)
return out
@@ -959,13 +979,13 @@ class CUDAGraphNode:
for i, elem in enumerate(inputs)
if i not in self.wrapped_function.static_input_idxs
]
check_memory_pool(self.cuda_graphs_pool, memory)
check_memory_pool(self.device, self.cuda_graphs_pool, memory)
with preserve_rng_state(), torch.cuda.device(
self.device
), clear_cublas_manager(), torch.cuda.graph(
self.graph, stream=self.stream, pool=self.cuda_graphs_pool
):
), get_history_recording():
static_outputs = model(inputs)
# running model should reclaim memory
@@ -1056,7 +1076,9 @@ class CUDAGraphNode:
self.debug_check_invariants_after_invocation()
if config.triton.slow_path_cudagraph_asserts:
check_memory_pool(self.cuda_graphs_pool, list(self.path_live_weakrefs()))
check_memory_pool(
self.device, self.cuda_graphs_pool, list(self.path_live_weakrefs())
)
def _mark_prior_graph_output_as_aliased(self, index: PathOutputIndex):
"Remove a graph output from the unaliased, cached tensors in an ancestor node"
@@ -1413,21 +1435,40 @@ def get_block_addrs(pool_id, live_only=True):
return blocks
def check_memory_pool(pool_id, live_storages_ptrs: List[StorageWeakRefWrapper]):
assert all(isinstance(elem, StorageWeakRefWrapper) for elem in live_storages_ptrs)
gc.collect()
def format_tb(caching_allocator_trace):
formatted_traceback = []
MAX_LENGTH = 20
for entry in caching_allocator_trace["frames"][0:MAX_LENGTH]:
formatted_traceback.append(
traceback.FrameSummary(entry["filename"], entry["line"], entry["name"])
)
return "".join(traceback.format_list(formatted_traceback))
def check_memory_pool(device, pool_id, live_storages_ptrs: List[StorageWeakRefWrapper]):
assert all(
isinstance(elem, StorageWeakRefWrapper) for elem in live_storages_ptrs
) # noqa: C419
unique_storages = {stor.data_ptr() for stor in live_storages_ptrs if stor()}
# check if there is a divergence first, then do the expensive snapshot call after
# we know it will error
if torch._C._cuda_checkPoolLiveAllocations(device, pool_id, unique_storages):
return
segments = get_cudagraph_segments(pool_id)
allocated_not_in_live_storages = []
allocated_not_in_live_storages = {}
for segment in segments:
addr = segment["address"]
for block in segment["blocks"]:
if block["state"] == "active_allocated":
if addr not in unique_storages:
allocated_not_in_live_storages.append(addr)
allocated_not_in_live_storages[addr] = block
else:
unique_storages.remove(addr)
@@ -1438,11 +1479,19 @@ check_memory_pool(pool_id, live_storages_ptrs: List[StorageWeakRefWrapper]):
lambda: f"These storage data ptrs are not allocated in pool {pool_id} but should be {unique_storages}",
)
check(
len(allocated_not_in_live_storages) == 0,
lambda: f"These live storage data ptrs are in the cudagraph pool but not "
f"accounted for as an output of cudagraph trees {allocated_not_in_live_storages}",
)
if len(allocated_not_in_live_storages) != 0:
formatted = []
for dp, block in allocated_not_in_live_storages.items():
trace = (
format_tb(block["history"][-1]) if block.get("history", None) else None
)
formatted.append(f"Data Pointer: {dp}, history: \n{trace}")
formatted_s = "\n".join(formatted)
msg = (
f"These live storage data ptrs are in the cudagraph pool but not "
f"accounted for as an output of cudagraph trees: \n\n{formatted_s}"
)
raise RuntimeError(msg)
class ExecutionState(Enum):
@@ -1876,7 +1925,9 @@ class CUDAGraphTreeManager:
# Now the live blocks should be exactly equal to the live storages in private pool
if config.triton.slow_path_cudagraph_asserts:
check_memory_pool(self.cuda_graphs_thread_pool, live_storages_wrappers)
check_memory_pool(
self.device_index, self.cuda_graphs_thread_pool, live_storages_wrappers
)
for wrapper in live_storages_wrappers:
assert wrapper()
assert torch._C._has_Standard_Deleter(wrapper())


@@ -913,6 +913,10 @@ static void registerCudaDeviceProperties(PyObject* module) {
alloc_trace_max_entries,
alloc_trace_record_context);
});
m.def("_cuda_isHistoryEnabled", []() {
return c10::cuda::CUDACachingAllocator::isHistoryEnabled();
});
}
// We choose to ignore certain blocks that are currently allocated
@@ -1148,6 +1152,20 @@ static void registerCudaPluggableAllocator(PyObject* module) {
c10::cuda::CUDACachingAllocator::releasePool(device, mempool_id);
});
m.def(
"_cuda_checkPoolLiveAllocations",
[](int device,
at::cuda::MempoolId_t mempool_id,
const py::set& expected_live_allocations) {
std::unordered_set<void*> allocations;
allocations.reserve(expected_live_allocations.size());
for (auto& elem : expected_live_allocations) {
allocations.insert(reinterpret_cast<void*>(py::cast<size_t>(elem)));
}
return c10::cuda::CUDACachingAllocator::checkPoolLiveAllocations(
device, mempool_id, allocations);
});
m.def(
"_cuda_setCheckpointPoolState",
[](int device,