[fr][c10d] log trace capture enabled or not in flight recorder (#143865)

Summary:
Refactor logging for the flight recorder so that we can log whether the
capture was taken with or without stack trace capture enabled.
We introduce a new column ('trace_enabled') in the logger.

Test Plan:
Tested on a local job and noted that the correct output was produced.
Internal link: https://fburl.com/scuba/c10d_flight_recorder/ulhqnmhg

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143865
Approved by: https://github.com/fduwjj
This commit is contained in:
Chirag Pandya 2024-12-26 12:20:06 -08:00 committed by PyTorch MergeBot
parent 6bdf2addc5
commit 1cd70e7e23
3 changed files with 29 additions and 35 deletions

View file

@ -29,7 +29,6 @@
#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
#include <torch/csrc/distributed/c10d/TraceUtils.h>
#include <torch/csrc/distributed/c10d/Utils.hpp>
#include <torch/csrc/distributed/c10d/logger.hpp>
#include <torch/torch.h>
#include <optional>
@ -1275,20 +1274,11 @@ bool ProcessGroupNCCL::waitForFutureOrTimeout(
std::future<bool>& fut,
const std::chrono::milliseconds& timeOutMilSec,
const std::string& futDescription,
bool throwException,
bool log) {
::c10d::C10dLoggingData& debugLog,
bool throwException) {
std::string errorMsg;
bool complete = false;
::c10d::C10dLoggingData data;
if (log) {
data.integers["pg_id"] = static_cast<int64_t>(local_id_);
data.integers["rank"] = rank_;
data.integers["global_rank"] = globalRank();
data.integers["world_size"] = getSize();
data.strings["flight_recorder_version"] = c10d::version_val_str;
}
TORCH_CHECK(fut.valid(), "Expected a valid future");
std::future_status status = fut.wait_for(timeOutMilSec);
if (status == std::future_status::ready) {
@ -1299,9 +1289,7 @@ bool ProcessGroupNCCL::waitForFutureOrTimeout(
if (result) {
VLOG(2) << logPrefix()
<< "future successfully executed for: " << futDescription;
if (log) {
data.strings["status"] = "SUCCESS";
}
debugLog.strings["status"] = "SUCCESS";
complete = true;
}
} catch (const std::exception& e) {
@ -1311,20 +1299,17 @@ bool ProcessGroupNCCL::waitForFutureOrTimeout(
futDescription,
": ",
e.what());
if (log) {
data.strings["status"] = "EXCEPTION";
data.strings["exception"] = e.what();
}
debugLog.strings["status"] = "EXCEPTION";
debugLog.strings["exception"] = e.what();
LOG(ERROR) << errorMsg;
} catch (...) {
errorMsg = c10::str(
logPrefix(),
"Unknown exception thrown when waiting for future ",
futDescription);
if (log) {
data.strings["status"] = "EXCEPTION";
data.strings["exception"] = "Unknown exception";
}
debugLog.strings["status"] = "EXCEPTION";
debugLog.strings["exception"] = "Unknown exception";
LOG(ERROR) << errorMsg;
}
} else {
@ -1335,15 +1320,9 @@ bool ProcessGroupNCCL::waitForFutureOrTimeout(
" timed out after ",
timeOutMilSec.count(),
" ms");
data.strings["status"] = "TIMEOUT";
debugLog.strings["status"] = "TIMEOUT";
LOG(ERROR) << errorMsg;
}
if (log) {
auto logger = c10d::C10dLogger::getLogger();
if (logger) {
logger->log(data);
}
}
if (throwException && !errorMsg.empty()) {
C10_THROW_ERROR(DistBackendError, errorMsg);
}
@ -1418,8 +1397,9 @@ void ProcessGroupNCCL::abort() {
std::future<bool> fut =
std::async(std::launch::async, [this]() { return this->abortComms(); });
::c10d::C10dLoggingData debugLog;
waitForFutureOrTimeout(
fut, options_->timeout, "ProcessGroup abort", true, false);
fut, options_->timeout, "ProcessGroup abort", debugLog, true);
LOG(INFO) << logPrefix() << "ProcessGroupNCCL aborts successfully.";
// We need to wait for abort to finish before we can safely shut down
@ -1765,6 +1745,12 @@ void ProcessGroupNCCL::heartbeatMonitor() {
// Store debug info to storage if no other thread does it. (By default to
// local disk)
bool dumpStackTrace = true;
::c10d::C10dLoggingData debugLog;
debugLog.integers["pg_id"] = static_cast<int64_t>(local_id_);
debugLog.integers["rank"] = rank_;
debugLog.integers["global_rank"] = globalRank();
debugLog.integers["world_size"] = getSize();
debugLog.strings["flight_recorder_version"] = c10d::version_val_str;
for (int i = 0; i < 2; i++) {
std::future<bool> asyncDebugDump =
std::async(std::launch::async, [this, dumpStackTrace]() {
@ -1776,8 +1762,8 @@ void ProcessGroupNCCL::heartbeatMonitor() {
asyncDebugDump,
std::chrono::milliseconds(waitTimeoutDumpInMilSec_),
"Flight recorder dump in heartbeatMonitor",
false,
true);
debugLog,
false);
if (complete) {
LOG(INFO)
@ -1789,6 +1775,11 @@ void ProcessGroupNCCL::heartbeatMonitor() {
// iteration.
dumpStackTrace = false;
}
debugLog.integers["trace_enabled"] = int64_t(dumpStackTrace);
auto logger = c10d::C10dLogger::getLogger();
if (logger) {
logger->log(debugLog);
}
// Indicate to watchdog thread that we have finished dumping.
promiseFlightRecorderDump_.set_value();
}

View file

@ -24,6 +24,7 @@
#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
#include <torch/csrc/distributed/c10d/Store.hpp>
#include <torch/csrc/distributed/c10d/intra_node_comm.hpp>
#include <torch/csrc/distributed/c10d/logger.hpp>
#include <ATen/DynamicLibrary.h>
#include <ATen/cuda/CUDAContext.h>
@ -985,8 +986,8 @@ class TORCH_API ProcessGroupNCCL : public Backend {
std::future<bool>& fut,
const std::chrono::milliseconds& timeOutMilSec,
const std::string& futDescription,
bool throwException = false,
bool log = false);
::c10d::C10dLoggingData& debugLog,
bool throwException = false);
std::string getNCCLWatchdogTimeoutErrorMsg(const std::string& extraMsg);

View file

@ -1,3 +1,5 @@
#pragma once
#include <c10/util/Logging.h>
#include <torch/csrc/distributed/c10d/reducer.hpp>