mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
allow profiling on all threads via experimentalConfig (#143659)
In some situations we want to profile calls coming from all threads (similar to on-demand), not just the thread that started profiling and the spawned threads that would inherit KinetoThreadLocal state. Pull Request resolved: https://github.com/pytorch/pytorch/pull/143659 Approved by: https://github.com/sraikund16
This commit is contained in:
parent
00831f9b22
commit
2ab698e708
5 changed files with 98 additions and 6 deletions
|
|
@ -2161,6 +2161,79 @@ assert KinetoStepTracker.current_step() == initial_step + 2 * niters
|
|||
self.payload(use_cuda=True)
|
||||
validate_json(prof, disable_external_correlation)
|
||||
|
||||
@skipIfTorchDynamo("profiler gets ignored if dynamo activated")
@unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
@unittest.skipIf(not kineto_available(), "Kineto is required")
def test_profile_all_threads(self):
    """Check _ExperimentalConfig(profile_all_threads=...) behavior.

    With profile_all_threads=True, aten::mm events (with recorded input
    shapes) must be captured from a thread that did NOT start the profiler;
    with False, only the profiling thread (and threads it spawns inside the
    profiled region) are captured.
    """
    profiling_started = threading.Event()
    profiling_ended = threading.Event()
    n_rep = 5

    def prep_inputs():
        # Two square CUDA matrices; each `x @ y` below emits one aten::mm.
        return [torch.randn(1024, 1024, device="cuda") for _ in range(2)]

    def main_thread_fn(profile_all_threads, returned_events):
        # Runs the profiler; signals the side thread once profiling is live
        # and keeps the profiled region open until the side thread is done.
        x, y = prep_inputs()
        experimental_config = torch._C._profiler._ExperimentalConfig(
            profile_all_threads=profile_all_threads
        )
        with torch.profiler.profile(
            experimental_config=experimental_config, record_shapes=True
        ) as p:
            profiling_started.set()
            for _ in range(n_rep):
                _ = x @ y
            profiling_ended.wait()
        returned_events.append(p.events())

    def side_thread_fn():
        # Independent thread: waits until profiling is active, runs its
        # matmuls, then allows the profiled region to close.
        x, y = prep_inputs()
        profiling_started.wait()
        for _ in range(n_rep):
            _ = x @ y
        profiling_ended.set()

    def main_with_thread_fn(profile_all_threads):
        # Spawns the side thread from *inside* the profiled region, so its
        # work is captured regardless of profile_all_threads.
        x, y = prep_inputs()
        experimental_config = torch._C._profiler._ExperimentalConfig(
            profile_all_threads=profile_all_threads
        )
        with torch.profiler.profile(
            experimental_config=experimental_config, record_shapes=True
        ) as p:
            side_thread = threading.Thread(target=side_thread_fn)
            side_thread.start()
            for _ in range(n_rep):
                _ = x @ y
            side_thread.join()
        return p.events()

    for profile_all_threads in (True, False):
        # Fix: reset both events at the start of each iteration. Without
        # this, the second iteration sees the events still set from the
        # first, so profiling_started.wait() / profiling_ended.wait()
        # return immediately and the threads no longer synchronize as
        # intended.
        profiling_started.clear()
        profiling_ended.clear()
        returned_events = []
        main_thread = threading.Thread(
            target=main_thread_fn, args=(profile_all_threads, returned_events)
        )
        side_thread = threading.Thread(target=side_thread_fn)
        main_thread.start()
        side_thread.start()
        main_thread.join()
        side_thread.join()

        def verify_events(events):
            # Count aten::mm occurrences per thread and validate that
            # record_shapes captured the 1024x1024 operand shapes.
            mm_events = collections.defaultdict(int)
            for e in events:
                if e.name == "aten::mm":
                    mm_events[e.thread] += 1
                    self.assertEqual(
                        e.input_shapes, [[1024, 1024], [1024, 1024]]
                    )
            # One profiled thread normally; two when profiling all threads.
            self.assertEqual(len(mm_events), 1 + int(profile_all_threads))
            for v in mm_events.values():
                self.assertEqual(v, n_rep)

        verify_events(returned_events[0])
        # test spawning thread from within the profiled region
        events = main_with_thread_fn(profile_all_threads)
        verify_events(events)
|
||||
|
||||
|
||||
class SimpleNet(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
|
|
|
|||
|
|
@ -771,8 +771,9 @@ void enableProfiler(
|
|||
KinetoThreadLocalState::push(state_ptr);
|
||||
|
||||
if (has_cpu) {
|
||||
config.global() ? pushProfilingCallbacks</*global=*/true>(scopes)
|
||||
: pushProfilingCallbacks</*global=*/false>(scopes);
|
||||
config.pushGlobalCallbacks()
|
||||
? pushProfilingCallbacks</*global=*/true>(scopes)
|
||||
: pushProfilingCallbacks</*global=*/false>(scopes);
|
||||
}
|
||||
|
||||
if (!config.global()) {
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ ExperimentalConfig::ExperimentalConfig(
|
|||
bool enable_cuda_sync_events,
|
||||
bool adjust_profiler_step,
|
||||
bool disable_external_correlation,
|
||||
bool profile_all_threads,
|
||||
bool adjust_timestamps)
|
||||
: profiler_metrics{std::move(profiler_metrics)},
|
||||
profiler_measure_per_kernel{profiler_measure_per_kernel},
|
||||
|
|
@ -27,6 +28,7 @@ ExperimentalConfig::ExperimentalConfig(
|
|||
enable_cuda_sync_events{enable_cuda_sync_events},
|
||||
adjust_profiler_step{adjust_profiler_step},
|
||||
disable_external_correlation{disable_external_correlation},
|
||||
profile_all_threads{profile_all_threads},
|
||||
adjust_timestamps{adjust_timestamps} {}
|
||||
|
||||
/*explicit*/ ExperimentalConfig::operator bool() const {
|
||||
|
|
@ -59,6 +61,10 @@ bool ProfilerConfig::global() const {
|
|||
return state == torch::profiler::impl::ProfilerState::KINETO_ONDEMAND;
|
||||
}
|
||||
|
||||
bool ProfilerConfig::pushGlobalCallbacks() const {
|
||||
return global() || experimental_config.profile_all_threads;
|
||||
}
|
||||
|
||||
namespace {
|
||||
enum ProfilerIValueIdx {
|
||||
STATE = 0,
|
||||
|
|
@ -114,14 +120,15 @@ ProfilerStateBase::~ProfilerStateBase() {
|
|||
? GlobalManager::get()
|
||||
: static_cast<ProfilerStateBase*>(
|
||||
c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::PROFILER_STATE));
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!out || out->config().global() == global);
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
|
||||
!out || out->config().pushGlobalCallbacks() == global);
|
||||
return out;
|
||||
}
|
||||
|
||||
/*static*/ void ProfilerStateBase::push(
|
||||
std::shared_ptr<ProfilerStateBase>&& state) {
|
||||
TORCH_INTERNAL_ASSERT(state != nullptr);
|
||||
if (state->config().global()) {
|
||||
if (state->config().pushGlobalCallbacks()) {
|
||||
GlobalManager::push(std::move(state));
|
||||
} else {
|
||||
c10::ThreadLocalDebugInfo::_push(c10::DebugInfoKind::PROFILER_STATE, state);
|
||||
|
|
|
|||
|
|
@ -59,6 +59,7 @@ struct TORCH_API ExperimentalConfig {
|
|||
bool enable_cuda_sync_events = false,
|
||||
bool adjust_profiler_step = false,
|
||||
bool disable_external_correlation = false,
|
||||
bool profile_all_threads = false,
|
||||
bool adjust_timestamps = false);
|
||||
explicit operator bool() const;
|
||||
|
||||
|
|
@ -89,6 +90,11 @@ struct TORCH_API ExperimentalConfig {
|
|||
*/
|
||||
bool disable_external_correlation;
|
||||
|
||||
/* controls whether profiler records cpu events on threads
|
||||
* that are not spawned from the main thread on which the
|
||||
* profiler was enabled, similar to on_demand mode */
|
||||
bool profile_all_threads;
|
||||
|
||||
/*
|
||||
* Controls whether or not timestamp adjustment occurs after profiling.
|
||||
* The purpose of this is to adjust Vulkan event timelines to align with those
|
||||
|
|
@ -115,6 +121,7 @@ struct TORCH_API ProfilerConfig {
|
|||
|
||||
bool disabled() const;
|
||||
bool global() const;
|
||||
bool pushGlobalCallbacks() const;
|
||||
|
||||
ProfilerState state;
|
||||
ExperimentalConfig experimental_config;
|
||||
|
|
|
|||
|
|
@ -337,7 +337,8 @@ void initPythonBindings(PyObject* module) {
|
|||
std::vector<std::string> /* performance_events */,
|
||||
bool /* enable_cuda_sync_events */,
|
||||
bool /* adjust_profiler_step */,
|
||||
bool /* disable_external_correlation*/
|
||||
bool /* disable_external_correlation*/,
|
||||
bool /* profile_all_threads */
|
||||
>(),
|
||||
"An experimental config for Kineto features. Please note that"
|
||||
"backward compatibility is not guaranteed.\n"
|
||||
|
|
@ -354,13 +355,15 @@ void initPythonBindings(PyObject* module) {
|
|||
" adjust_profiler_step (bool) : whether to adjust the profiler step to\n"
|
||||
" match the parent python event duration. This feature is new and currently disabled by default.\n",
|
||||
" disable_external_correlation (bool) : whether to disable external correlation\n",
|
||||
" profile_all_threads (bool) : whether to profile all threads\n",
|
||||
py::arg("profiler_metrics") = std::vector<std::string>(),
|
||||
py::arg("profiler_measure_per_kernel") = false,
|
||||
py::arg("verbose") = false,
|
||||
py::arg("performance_events") = std::vector<std::string>(),
|
||||
py::arg("enable_cuda_sync_events") = false,
|
||||
py::arg("adjust_profiler_step") = false,
|
||||
py::arg("disable_external_correlation") = false)
|
||||
py::arg("disable_external_correlation") = false,
|
||||
py::arg("profile_all_threads") = false)
|
||||
.def(py::pickle(
|
||||
[](const ExperimentalConfig& p) { // __getstate__
|
||||
py::list py_metrics;
|
||||
|
|
@ -381,6 +384,7 @@ void initPythonBindings(PyObject* module) {
|
|||
p.enable_cuda_sync_events,
|
||||
p.adjust_profiler_step,
|
||||
p.disable_external_correlation,
|
||||
p.profile_all_threads,
|
||||
p.performance_events);
|
||||
},
|
||||
[](const py::tuple& t) { // __setstate__
|
||||
|
|
|
|||
Loading…
Reference in a new issue