Most recently used cache management for TorchDynamo (#88076)

Modify the lookup procedure for TorchDynamo caches to keep the head of the singly linked list as the most recently used cache entry, which may improve the probability of a cache hit.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88076
Approved by: https://github.com/jansel
This commit is contained in:
zyq8709 2022-11-08 18:46:56 +00:00 committed by PyTorch MergeBot
parent 1b5373fc83
commit eaf4fe3d2b

View file

@ -191,6 +191,17 @@ static void destroy_cache_entry(CacheEntry* e) {
free(e);
}
// Return the head of the CacheEntry list stored in `code`'s co_extra
// slot at `extra_index` (PEP 523 per-code-object extra storage), or
// NULL if nothing has been stored yet.
// NOTE(review): the int return of _PyCode_GetExtra is ignored; on
// failure `extra` remains NULL, which callers treat as "no cache" —
// presumably acceptable here, but confirm no error state can leak.
inline static CacheEntry* get_extra(PyCodeObject* code) {
CacheEntry* extra = NULL;
_PyCode_GetExtra((PyObject*)code, extra_index, (void*)&extra);
return extra;
}
// Store `extra` as the new head of the CacheEntry list in `code`'s
// co_extra slot at `extra_index`. Ownership of the list stays with
// the caller/cache machinery; this only updates the slot pointer.
// NOTE(review): _PyCode_SetExtra's return value is ignored — a failed
// set would silently drop the cache update; verify this is intended.
inline static void set_extra(PyCodeObject* code, CacheEntry* extra) {
// TODO(jansel): would it be faster to bypass this?
_PyCode_SetExtra((PyObject*)code, extra_index, extra);
}
#ifdef TORCHDYNAMO_DEBUG
inline static const char* name(PyFrameObject* frame) {
DEBUG_CHECK(PyUnicode_Check(frame->f_code->co_name));
@ -216,10 +227,11 @@ static void call_guard_fail_hook(
Py_DECREF(args);
}
static PyCodeObject* lookup(CacheEntry* e, PyObject* f_locals) {
static PyCodeObject* lookup(CacheEntry* e, PyFrameObject *frame, CacheEntry* prev) {
if (e == NULL) {
return NULL;
}
PyObject *f_locals = frame->f_locals;
PyObject* dotzero = PyDict_GetItem(f_locals, dotzerokey);
PyObject* valid = NULL;
if (unlikely(dotzero != NULL)) {
@ -240,12 +252,21 @@ static PyCodeObject* lookup(CacheEntry* e, PyObject* f_locals) {
}
Py_DECREF(valid);
if (valid == Py_True) {
// Keep the head as the most recently used cache entry.
// If the hit cache entry is not the head of the linked list,
// move it to the head
if (prev != NULL) {
CacheEntry* extra = get_extra(frame->f_code);
prev->next = e->next;
e->next = extra;
set_extra(frame->f_code, e);
}
return e->code;
}
if (unlikely(guard_fail_hook != NULL)) {
call_guard_fail_hook(guard_fail_hook, e, f_locals);
}
return lookup(e->next, f_locals);
return lookup(e->next, frame, e);
}
static long cache_size(CacheEntry* e) {
@ -255,17 +276,6 @@ static long cache_size(CacheEntry* e) {
return 1 + cache_size(e->next);
}
// NOTE(review): in the underlying diff this hunk DELETES the two
// definitions below — get_extra/set_extra were moved earlier in the
// file (above lookup) so that lookup() can call them when relinking
// the MRU entry. These are the original, now-redundant copies.
inline static CacheEntry* get_extra(PyCodeObject* code) {
CacheEntry* extra = NULL;
_PyCode_GetExtra((PyObject*)code, extra_index, (void*)&extra);
return extra;
}
inline static void set_extra(PyCodeObject* code, CacheEntry* extra) {
// TODO(jansel): would it be faster to bypass this?
_PyCode_SetExtra((PyObject*)code, extra_index, extra);
}
inline static PyObject* eval_custom_code(
PyThreadState* tstate,
PyFrameObject* frame,
@ -358,7 +368,7 @@ static PyObject* _custom_eval_frame(
// we never compile.
if (callback == Py_False) {
DEBUG_TRACE("In run only mode %s", name(frame));
PyCodeObject* cached_code = lookup(extra, frame->f_locals);
PyCodeObject* cached_code = lookup(extra, frame, NULL);
if (cached_code != NULL) {
// used cached version
DEBUG_TRACE("cache hit %s", name(frame));
@ -377,7 +387,7 @@ static PyObject* _custom_eval_frame(
// in the shim.
eval_frame_callback_set(Py_None);
PyCodeObject* cached_code = lookup(extra, frame->f_locals);
PyCodeObject* cached_code = lookup(extra, frame, NULL);
if (cached_code != NULL) {
// used cached version
DEBUG_TRACE("cache hit %s", name(frame));