mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[Profiler] Defer recording startup python events (take 2) (#91684)
This is my commandeer of https://github.com/pytorch/pytorch/pull/82154 with a couple extra fixes.
The high level idea is that when we start profiling we see python frames which are currently executing, but we don't know what system TID created them. So instead we defer the TID assignment, and then during post processing we peer into the future and use the system TID *of the next* call on that Python TID.
As an aside, it turns out that CPython does some bookkeeping (ee821dcd39/Include/cpython/pystate.h (L159-L165), thanks @dzhulgakov for the pointer), but you'd have to do some extra work at runtime to know how to map their TID to ours so for now I'm going to stick to what I can glean from post processing alone.
As we start observing more threads it becomes more important to be principled about how we start up and shut down. (Since threads may die while the profiler is running.) #82154 had various troubles with segfaults that wound up being related to accessing Python thread pointers which were no longer alive. I've tweaked the startup and shutdown interaction with the CPython interpreter and it should be safer now.
Differential Revision: [D42336292](https://our.internmc.facebook.com/intern/diff/D42336292/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91684
Approved by: https://github.com/chaekit
This commit is contained in:
parent
8d45f555d7
commit
d09cd15216
2 changed files with 267 additions and 35 deletions
|
|
@ -7,6 +7,7 @@ import os
|
|||
import re
|
||||
import tempfile
|
||||
import textwrap
|
||||
import threading
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
import weakref
|
||||
|
|
@ -57,6 +58,8 @@ from torch.testing._internal.common_cuda import TEST_MULTIGPU
|
|||
from torch.testing._internal.common_device_type import skipCUDAVersionIn
|
||||
from torch.testing._internal.common_utils import (
|
||||
IS_WINDOWS,
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
run_tests,
|
||||
TemporaryDirectoryName,
|
||||
TemporaryFileName,
|
||||
|
|
@ -478,6 +481,7 @@ class TestExecutionGraph(TestCase):
|
|||
assert found_root_node
|
||||
|
||||
|
||||
@instantiate_parametrized_tests
|
||||
class TestProfiler(TestCase):
|
||||
|
||||
@unittest.skipIf(TEST_WITH_CROSSREF, "crossref intercepts calls and changes the callsite.")
|
||||
|
|
@ -549,6 +553,161 @@ class TestProfiler(TestCase):
|
|||
|
||||
torch._C._set_graph_executor_optimize(prev_opt)
|
||||
|
||||
@parametrize(
|
||||
"name,thread_spec",
|
||||
{
|
||||
"basic": ((False, False),),
|
||||
"multiple_preexisting": ((False, False), ) * 2,
|
||||
"open_in_scope": ((True, False),),
|
||||
"close_in_scope": ((False, True),),
|
||||
"complex": (
|
||||
# Large number of background threads
|
||||
(False, False),
|
||||
(False, False),
|
||||
(False, False),
|
||||
(False, False),
|
||||
|
||||
# some of which finish during profiling
|
||||
(False, True),
|
||||
(False, True),
|
||||
|
||||
# And the profiled section is also multithreaded
|
||||
(True, False),
|
||||
(True, True),
|
||||
|
||||
),
|
||||
}.items(),
|
||||
name_fn=lambda name, thread_spec: name
|
||||
)
|
||||
@parametrize("work_in_main_thread", [True, False])
|
||||
def test_source_multithreaded(self, name, thread_spec, work_in_main_thread):
|
||||
"""Test various threading configurations.
|
||||
|
||||
`thread_spec` is a Tuple[Tuple[bool, bool], ...] where each pair is a
|
||||
thread. The first bool indicates if the thread should be started under
|
||||
the profiler context and the second is if it should be joined under the
|
||||
profiler context.
|
||||
"""
|
||||
|
||||
timeout = 15
|
||||
num_threads = len(thread_spec) + 1 # Main thread
|
||||
start_barrier = threading.Barrier(num_threads, timeout=timeout)
|
||||
end_barrier = threading.Barrier(num_threads, timeout=timeout)
|
||||
|
||||
class Task(threading.Thread):
|
||||
|
||||
def __init__(self):
|
||||
self._end_gate = threading.Event()
|
||||
super().__init__(daemon=True)
|
||||
self.start()
|
||||
self.finished = False
|
||||
|
||||
def run(self):
|
||||
self._run(self._end_gate)
|
||||
|
||||
def release(self):
|
||||
self._end_gate.set()
|
||||
|
||||
@staticmethod
|
||||
def _run(end_gate=None):
|
||||
|
||||
def known_preexisting_function():
|
||||
start_barrier.wait()
|
||||
|
||||
# Fixed point that we can use to test capture of functions
|
||||
# which are already running when profiling is enabled.
|
||||
known_preexisting_function()
|
||||
|
||||
model = torch.nn.Sequential(
|
||||
torch.nn.Linear(10, 10),
|
||||
torch.nn.ReLU(),
|
||||
)
|
||||
|
||||
def invoked_during_run():
|
||||
pass
|
||||
|
||||
invoked_during_run()
|
||||
|
||||
_ = model(torch.rand(4, 10))
|
||||
end_barrier.wait()
|
||||
|
||||
if end_gate is not None:
|
||||
end_gate.wait(timeout=timeout)
|
||||
|
||||
threads = {}
|
||||
|
||||
def add_threads(context: bool):
|
||||
for idx, (start_under_profiler, _) in enumerate(thread_spec):
|
||||
if start_under_profiler == context:
|
||||
assert idx not in threads
|
||||
threads[idx] = Task()
|
||||
|
||||
def join_threads(context: bool):
|
||||
for idx, (_, end_under_profiler) in enumerate(thread_spec):
|
||||
if end_under_profiler == context:
|
||||
threads[idx].release()
|
||||
|
||||
for idx, (_, end_under_profiler) in enumerate(thread_spec):
|
||||
t = threads[idx]
|
||||
if end_under_profiler == context:
|
||||
t.join(timeout=timeout)
|
||||
|
||||
try:
|
||||
add_threads(False)
|
||||
with torch.profiler.profile(with_stack=True) as prof:
|
||||
# Threads added while the profiler are running will not be observed
|
||||
# since there is no way to hook into Python's thread start call to
|
||||
# register the observer. These are here purely to verify safety.
|
||||
add_threads(True)
|
||||
|
||||
if work_in_main_thread:
|
||||
Task._run()
|
||||
else:
|
||||
start_barrier.wait()
|
||||
end_barrier.wait()
|
||||
|
||||
join_threads(True)
|
||||
join_threads(False)
|
||||
|
||||
finally:
|
||||
# It is very important that we clean up everything because the
|
||||
# Python tracer will detect ALL active threads. (Even orphans from
|
||||
# prior failed tests.) If we don't clean up properly we can
|
||||
# contaminate subsequent tests.
|
||||
start_barrier.abort()
|
||||
end_barrier.abort()
|
||||
for t in threads.values():
|
||||
t.release()
|
||||
|
||||
for t in threads.values():
|
||||
t.join(timeout=timeout)
|
||||
|
||||
for t in threads.values():
|
||||
self.assertFalse(t.is_alive())
|
||||
|
||||
roots = prof.profiler.kineto_results.experimental_event_tree()
|
||||
nodes = [node for node in _utils.traverse_dfs(roots) if isinstance(node.extra_fields, _ExtraFields_PyCall)]
|
||||
tid_counts = collections.Counter([node.start_tid for node in nodes])
|
||||
|
||||
prior_threads = sum(not start_under_profiler for start_under_profiler, _ in thread_spec)
|
||||
expected_threads = prior_threads + 1
|
||||
self.assertEqual(len(tid_counts), expected_threads, f"{expected_threads}, {tid_counts}")
|
||||
self.assertEqual(len(nodes), sum(tid_counts.values()))
|
||||
|
||||
# Profiler uses uint64_t max as a placeholder until TID can be determined.
|
||||
no_tid = 2 ** 64 - 1
|
||||
self.assertFalse(no_tid in tid_counts)
|
||||
|
||||
worker_threads = prior_threads + (1 if work_in_main_thread else 0)
|
||||
|
||||
observed_preexisting = [node.start_tid for node in nodes if "known_preexisting_function" in node.name]
|
||||
self.assertEqual(len(observed_preexisting), worker_threads)
|
||||
self.assertEqual(len(observed_preexisting), len(set(observed_preexisting)))
|
||||
|
||||
observed_during_run = [node.start_tid for node in nodes if "invoked_during_run" in node.name]
|
||||
self.assertEqual(len(observed_during_run), worker_threads)
|
||||
self.assertEqual(len(observed_during_run), len(set(observed_during_run)))
|
||||
|
||||
def payload(self, use_cuda=False):
|
||||
x = torch.randn(10, 10)
|
||||
if use_cuda:
|
||||
|
|
|
|||
|
|
@ -41,6 +41,7 @@ namespace {
|
|||
enum CallType { PyCall = 0, PyModuleCall, PyCCall, PyOptimizerCall };
|
||||
static constexpr size_t CallTypeSize = 4;
|
||||
using no_ephemeral_t = std::tuple<>;
|
||||
static constexpr uint64_t NoTID = std::numeric_limits<uint64_t>::max();
|
||||
|
||||
// ============================================================================
|
||||
// == Miscellaneous structs and utils =========================================
|
||||
|
|
@ -600,6 +601,29 @@ static PyTypeObject TraceContextType = {
|
|||
nullptr /* tp_free */
|
||||
};
|
||||
|
||||
class gil_and_restore_thread {
|
||||
public:
|
||||
gil_and_restore_thread()
|
||||
: gil_(), initial_thread_state_{PyThreadState_Get()} {}
|
||||
~gil_and_restore_thread() {
|
||||
PyThreadState_Swap(initial_thread_state_);
|
||||
|
||||
// `gil_scoped_acquire` is a bit fragile in on-demand mode:
|
||||
// https://github.com/pytorch/pytorch/pull/91684#issuecomment-1413154458
|
||||
if (!Py_IsInitialized()) {
|
||||
gil_.disarm();
|
||||
}
|
||||
}
|
||||
|
||||
PyThreadState* initial_thread_state() const {
|
||||
return initial_thread_state_;
|
||||
}
|
||||
|
||||
private:
|
||||
pybind11::gil_scoped_acquire gil_;
|
||||
PyThreadState* initial_thread_state_;
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
// == Thread local cache ======================================================
|
||||
// ============================================================================
|
||||
|
|
@ -666,26 +690,53 @@ class PythonTracer final : public python_tracer::PythonTracerBase {
|
|||
std::vector<python_tracer::CompressedEvent>& enters,
|
||||
time_t end_time_ns) override;
|
||||
|
||||
struct StartFrame {
|
||||
TraceKey trace_key_;
|
||||
approx_time_t start_time;
|
||||
};
|
||||
|
||||
private:
|
||||
void recordPyCall(ThreadLocalResults& tls, PyFrameObject* frame);
|
||||
void recordPyCall(
|
||||
ThreadLocalResults& tls,
|
||||
PyFrameObject* frame,
|
||||
bool is_startup_frame);
|
||||
|
||||
void recordCCall(
|
||||
ThreadLocalResults& tls,
|
||||
PyFrameObject* frame,
|
||||
PyObject* arg);
|
||||
|
||||
const std::vector<PyThreadState*> interpreterThreads() const;
|
||||
|
||||
std::atomic<bool> active_lock_{false};
|
||||
bool active_{false};
|
||||
|
||||
torch::profiler::impl::RecordQueue* queue_;
|
||||
PyInterpreterState* interpreter_;
|
||||
PyCodeObject* module_call_code_;
|
||||
PyCodeObject* optimizer_hook_;
|
||||
|
||||
std::vector<StartFrame> start_frames_;
|
||||
std::deque<ThreadLocalResults> thread_local_results_;
|
||||
ValueCache value_cache_;
|
||||
};
|
||||
|
||||
const std::vector<PyThreadState*> PythonTracer::interpreterThreads() const {
|
||||
pybind11::gil_scoped_acquire gil;
|
||||
std::vector<PyThreadState*> out;
|
||||
if (SOFT_ASSERT(interpreter_)) {
|
||||
auto* thread_state = PyInterpreterState_ThreadHead(interpreter_);
|
||||
while (thread_state != nullptr) {
|
||||
out.push_back(thread_state);
|
||||
thread_state = PyThreadState_Next(thread_state);
|
||||
}
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
PythonTracer::PythonTracer(torch::profiler::impl::RecordQueue* queue)
|
||||
: queue_(queue),
|
||||
interpreter_(nullptr),
|
||||
module_call_code_(getCode<CallType::PyModuleCall>()),
|
||||
optimizer_hook_(getCode<CallType::PyOptimizerCall>()) {
|
||||
TORCH_CHECK(queue_ != nullptr);
|
||||
|
|
@ -699,29 +750,16 @@ PythonTracer::PythonTracer(torch::profiler::impl::RecordQueue* queue)
|
|||
return;
|
||||
}
|
||||
|
||||
pybind11::gil_scoped_acquire gil;
|
||||
gil_and_restore_thread gil;
|
||||
interpreter_ = PyInterpreterState_Get();
|
||||
|
||||
// Loop over all threads within the current interpreter. We will need to
|
||||
// register a trace function with each thread. We set the current thread to
|
||||
// position zero to ensure that it is traced, and so we can restore the
|
||||
// thread state after registration. The profiler cannot post process multiple
|
||||
// python threads yet, so this section is temporarily disabled.
|
||||
std::vector<PyThreadState*> thread_states{PyThreadState_Get()};
|
||||
/*
|
||||
if (all_threads) {
|
||||
auto thread_state = thread_states[0];
|
||||
while (thread_state != nullptr) {
|
||||
if (thread_state != thread_states[0]) {
|
||||
thread_states.push_back(thread_state);
|
||||
}
|
||||
thread_state = PyThreadState_Next(thread_state);
|
||||
}
|
||||
if (!gil.initial_thread_state()) {
|
||||
TORCH_WARN("PyThreadState_Get returned NULL");
|
||||
return;
|
||||
}
|
||||
*/
|
||||
|
||||
// Register the tracer in each thread.
|
||||
for (const auto i : c10::irange(thread_states.size())) {
|
||||
PyThreadState* thread_state = thread_states[i];
|
||||
for (const auto thread_state : interpreterThreads()) {
|
||||
PyThreadState_Swap(thread_state);
|
||||
|
||||
thread_local_results_.emplace_back(thread_state, &value_cache_, this);
|
||||
|
|
@ -747,7 +785,7 @@ PythonTracer::PythonTracer(torch::profiler::impl::RecordQueue* queue)
|
|||
}
|
||||
|
||||
for (auto it = current_stack.rbegin(); it != current_stack.rend(); it++) {
|
||||
recordPyCall(thread_local_results_.back(), it->get());
|
||||
recordPyCall(thread_local_results_.back(), it->get(), true);
|
||||
auto frame_refcount = Py_REFCNT(it->get());
|
||||
|
||||
// We hold one reference in `current_stack`, and the interpreter holds
|
||||
|
|
@ -760,20 +798,17 @@ PythonTracer::PythonTracer(torch::profiler::impl::RecordQueue* queue)
|
|||
// cannot be round tripped via `sys.settrace(sys.gettrace())`
|
||||
PyEval_SetProfile(PythonTracer::pyProfileFn, (PyObject*)ctx);
|
||||
}
|
||||
|
||||
// Restore the thread state to its initial value.
|
||||
PyThreadState_Swap(thread_states[0]);
|
||||
};
|
||||
|
||||
void PythonTracer::stop() {
|
||||
pybind11::gil_scoped_acquire gil;
|
||||
gil_and_restore_thread gil;
|
||||
if (active_) {
|
||||
PyThreadState* initial_thread_state = PyThreadState_Get();
|
||||
for (const auto& i : thread_local_results_) {
|
||||
PyThreadState_Swap(i.thread_state_);
|
||||
PyEval_SetProfile(nullptr, nullptr);
|
||||
for (const auto thread_state : interpreterThreads()) {
|
||||
if (thread_state->c_profilefunc == &PythonTracer::pyProfileFn) {
|
||||
PyThreadState_Swap(thread_state);
|
||||
PyEval_SetProfile(nullptr, nullptr);
|
||||
}
|
||||
}
|
||||
PyThreadState_Swap(initial_thread_state);
|
||||
|
||||
auto lock_returned = active_lock_.compare_exchange_strong(active_, false);
|
||||
active_ = false;
|
||||
|
|
@ -788,9 +823,12 @@ PythonTracer::~PythonTracer() {
|
|||
}
|
||||
}
|
||||
|
||||
void PythonTracer::recordPyCall(ThreadLocalResults& tls, PyFrameObject* frame) {
|
||||
void PythonTracer::recordPyCall(
|
||||
ThreadLocalResults& tls,
|
||||
PyFrameObject* frame,
|
||||
bool is_startup_frame) {
|
||||
static constexpr auto E = EventType::PyCall;
|
||||
auto get_key = [&]() -> TraceKey {
|
||||
const auto key = [&]() -> TraceKey {
|
||||
auto code = THPCodeObjectPtr(PyFrame_GetCode(frame));
|
||||
if (code.get() == module_call_code_) {
|
||||
// By default, CPython stores locals in a "fast" format, with an array
|
||||
|
|
@ -822,8 +860,10 @@ void PythonTracer::recordPyCall(ThreadLocalResults& tls, PyFrameObject* frame) {
|
|||
auto f_back = (back.get() != nullptr) ? back.get() : frame;
|
||||
return tls.intern<CallType::PyCall, E>(no_ephemeral_t(), frame, f_back);
|
||||
}
|
||||
};
|
||||
queue_->getSubqueue()->emplace_py_call(get_key(), getApproximateTime());
|
||||
}();
|
||||
const auto time = getApproximateTime();
|
||||
is_startup_frame ? start_frames_.push_back({key, time})
|
||||
: queue_->getSubqueue()->emplace_py_call(key, time);
|
||||
}
|
||||
|
||||
void PythonTracer::recordCCall(
|
||||
|
|
@ -869,6 +909,18 @@ class PostProcess {
|
|||
}
|
||||
}
|
||||
|
||||
void set_start_frames(
|
||||
const std::vector<PythonTracer::StartFrame>& start_frames,
|
||||
std::vector<python_tracer::CompressedEvent>& enters) {
|
||||
for (const auto& frame : start_frames) {
|
||||
enters.push_back(
|
||||
{frame.trace_key_,
|
||||
NoTID, // Allows us to detect unhandled start frames
|
||||
{},
|
||||
time_converter_(frame.start_time)});
|
||||
}
|
||||
}
|
||||
|
||||
template <CallType C>
|
||||
void operator()(
|
||||
const TraceKeyCacheState<C>& trace_cache,
|
||||
|
|
@ -906,6 +958,7 @@ class PostProcess {
|
|||
std::vector<python_tracer::CompressedEvent>& enters,
|
||||
std::vector<std::shared_ptr<Result>>& out) {
|
||||
using stack_t = std::vector<std::shared_ptr<Result>>;
|
||||
const auto initial_size = out.size();
|
||||
auto pop = [](stack_t& stack, time_t t) {
|
||||
TORCH_INTERNAL_ASSERT(stack.size(), "Python replay stack is empty.");
|
||||
c10::get<ExtraFields<E>>(stack.back()->extra_fields_).end_time_ns_ = t;
|
||||
|
|
@ -939,6 +992,25 @@ class PostProcess {
|
|||
pop(i.second, end_time_);
|
||||
}
|
||||
}
|
||||
|
||||
// Assign system TIDs to start events based on the system TID of the next
|
||||
// observed event with the same Python TID.
|
||||
ska::flat_hash_map<size_t, std::pair<size_t, kineto::DeviceAndResource>>
|
||||
tid_map;
|
||||
auto it = out.rbegin();
|
||||
for (C10_UNUSED auto _ : c10::irange(initial_size, out.size())) {
|
||||
const auto python_tid =
|
||||
c10::get<ExtraFields<E>>((*it)->extra_fields_).python_tid_;
|
||||
if ((*it)->start_tid_ == NoTID && SOFT_ASSERT(E == EventType::PyCall)) {
|
||||
const auto& tid_info =
|
||||
tid_map.insert({python_tid, {NoTID, kineto::DeviceAndResource()}})
|
||||
.first->second;
|
||||
(*it)->start_tid_ = tid_info.first;
|
||||
(*it)->kineto_info_ = tid_info.second;
|
||||
}
|
||||
tid_map[python_tid] = {(*it)->start_tid_, (*it)->kineto_info_};
|
||||
++it;
|
||||
}
|
||||
}
|
||||
|
||||
template <EventType E>
|
||||
|
|
@ -989,6 +1061,7 @@ std::vector<std::shared_ptr<Result>> PythonTracer::getEvents(
|
|||
thread_local_results_,
|
||||
value_cache_,
|
||||
end_time_ns);
|
||||
post_process.set_start_frames(start_frames_, enters);
|
||||
auto out = post_process.run(enters);
|
||||
|
||||
std::stable_sort(out.begin(), out.end(), [](const auto& a, const auto& b) {
|
||||
|
|
@ -1015,7 +1088,7 @@ int PythonTracer::pyProfileFn(
|
|||
*reinterpret_cast<TraceContext*>(obj)->thread_local_results_;
|
||||
switch (what) {
|
||||
case PyTrace_CALL:
|
||||
local_results.active_tracer_->recordPyCall(local_results, frame);
|
||||
local_results.active_tracer_->recordPyCall(local_results, frame, false);
|
||||
break;
|
||||
|
||||
case PyTrace_C_CALL:
|
||||
|
|
|
|||
Loading…
Reference in a new issue