mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/55976 - Define a concrete `DebugInfo` to collect Param comms. - Add a macro to easily log `DebugInfo`. Test Plan: Tested on `ads:simplified_launcher` with `dyno gputrace`; locally tested in libkinetoObserver that it can collect the `DebugInfoBase`. Reviewed By: kingchc, ilia-cher. Differential Revision: D26773447. fbshipit-source-id: a8eeede2d6dbf34d7a1b3614843b4a1baba94448
87 lines
2.5 KiB
C++
87 lines
2.5 KiB
C++
#pragma once
|
|
|
|
#include <c10/macros/Export.h>
|
|
#include <c10/util/Exception.h>
|
|
|
|
#include <memory>
|
|
#include <string>
|
|
#include <unordered_map>
|
|
|
|
namespace c10 {
|
|
|
|
// Tag naming the kind of a debug info entry. Values of this type are passed
// to ThreadLocalDebugInfo::get/_push/_pop/_peek and to DebugInfoGuard.
// The underlying type is uint8_t; enumerators count up from 0 in declaration
// order.
enum class C10_API_ENUM DebugInfoKind : uint8_t {
  PRODUCER_INFO = 0,
  MOBILE_RUNTIME_INFO,
  PROFILER_STATE,
  INFERENCE_CONTEXT, // for inference usage
  PARAM_COMMS_INFO,

  // The kinds below are used only in tests.
  TEST_INFO,
  TEST_INFO_2,
};
|
|
|
|
class C10_API DebugInfoBase {
|
|
public:
|
|
DebugInfoBase() {}
|
|
virtual ~DebugInfoBase() {}
|
|
};
|
|
|
|
// Thread local debug information is propagated across the forward
|
|
// (including async fork tasks) and backward passes and is supposed
|
|
// to be utilized by the user's code to pass extra information from
|
|
// the higher layers (e.g. model id) down to the lower levels
|
|
// (e.g. to the operator observers used for debugging, logging,
|
|
// profiling, etc)
|
|
class C10_API ThreadLocalDebugInfo {
|
|
public:
|
|
static DebugInfoBase* get(DebugInfoKind kind);
|
|
|
|
// Get current ThreadLocalDebugInfo
|
|
static std::shared_ptr<ThreadLocalDebugInfo> current();
|
|
|
|
// Internal, use DebugInfoGuard/ThreadLocalStateGuard
|
|
static void _forceCurrentDebugInfo(
|
|
const std::shared_ptr<ThreadLocalDebugInfo>& info);
|
|
|
|
// Push debug info struct of a given kind
|
|
static void _push(DebugInfoKind kind, std::shared_ptr<DebugInfoBase> info);
|
|
// Pop debug info, throws in case the last pushed
|
|
// debug info is not of a given kind
|
|
static std::shared_ptr<DebugInfoBase> _pop(DebugInfoKind kind);
|
|
// Peek debug info, throws in case the last pushed debug info is not of the
|
|
// given kind
|
|
static std::shared_ptr<DebugInfoBase> _peek(DebugInfoKind kind);
|
|
|
|
private:
|
|
std::shared_ptr<DebugInfoBase> info_;
|
|
DebugInfoKind kind_;
|
|
std::shared_ptr<ThreadLocalDebugInfo> parent_info_;
|
|
|
|
friend class DebugInfoGuard;
|
|
};
|
|
|
|
// DebugInfoGuard is used to set debug information,
|
|
// ThreadLocalDebugInfo is semantically immutable, the values are set
|
|
// through the scope-based guard object.
|
|
// Nested DebugInfoGuard adds/overrides existing values in the scope,
|
|
// restoring the original values after exiting the scope.
|
|
// Users can access the values through the ThreadLocalDebugInfo::get() call;
|
|
class C10_API DebugInfoGuard {
|
|
public:
|
|
DebugInfoGuard(
|
|
DebugInfoKind kind, std::shared_ptr<DebugInfoBase> info);
|
|
|
|
explicit DebugInfoGuard(
|
|
std::shared_ptr<ThreadLocalDebugInfo> info);
|
|
|
|
~DebugInfoGuard();
|
|
|
|
DebugInfoGuard(const DebugInfoGuard&) = delete;
|
|
DebugInfoGuard(DebugInfoGuard&&) = delete;
|
|
|
|
private:
|
|
bool active_ = false;
|
|
std::shared_ptr<ThreadLocalDebugInfo> prev_info_ = nullptr;
|
|
};
|
|
|
|
} // namespace c10
|