From 1cd70e7e237d03de3f50445ab0c7975d6749dc5e Mon Sep 17 00:00:00 2001 From: Chirag Pandya Date: Thu, 26 Dec 2024 12:20:06 -0800 Subject: [PATCH] [fr][c10d] log trace capture enabled or not in flight recorder (#143865) Summary: Refactor logging for flight recorder so we can log if the capture was with or without stack trace capture enabled. We introduce a new column ('trace_enabled') in the logger. Test Plan: Tested on local job and noted that correct output was produced. Internal link: https://fburl.com/scuba/c10d_flight_recorder/ulhqnmhg Pull Request resolved: https://github.com/pytorch/pytorch/pull/143865 Approved by: https://github.com/fduwjj --- .../distributed/c10d/ProcessGroupNCCL.cpp | 57 ++++++++----------- .../distributed/c10d/ProcessGroupNCCL.hpp | 5 +- torch/csrc/distributed/c10d/logger.hpp | 2 + 3 files changed, 29 insertions(+), 35 deletions(-) diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 84b2ddc9759..df8de61474a 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -29,7 +29,6 @@ #include #include #include -#include #include #include @@ -1275,20 +1274,11 @@ bool ProcessGroupNCCL::waitForFutureOrTimeout( std::future& fut, const std::chrono::milliseconds& timeOutMilSec, const std::string& futDescription, - bool throwException, - bool log) { + ::c10d::C10dLoggingData& debugLog, + bool throwException) { std::string errorMsg; bool complete = false; - ::c10d::C10dLoggingData data; - if (log) { - data.integers["pg_id"] = static_cast(local_id_); - data.integers["rank"] = rank_; - data.integers["global_rank"] = globalRank(); - data.integers["world_size"] = getSize(); - data.strings["flight_recorder_version"] = c10d::version_val_str; - } - TORCH_CHECK(fut.valid(), "Expected a valid future"); std::future_status status = fut.wait_for(timeOutMilSec); if (status == std::future_status::ready) { @@ -1299,9 +1289,7 @@ bool 
ProcessGroupNCCL::waitForFutureOrTimeout( if (result) { VLOG(2) << logPrefix() << "future successfully executed for: " << futDescription; - if (log) { - data.strings["status"] = "SUCCESS"; - } + debugLog.strings["status"] = "SUCCESS"; complete = true; } } catch (const std::exception& e) { @@ -1311,20 +1299,17 @@ bool ProcessGroupNCCL::waitForFutureOrTimeout( futDescription, ": ", e.what()); - if (log) { - data.strings["status"] = "EXCEPTION"; - data.strings["exception"] = e.what(); - } + + debugLog.strings["status"] = "EXCEPTION"; + debugLog.strings["exception"] = e.what(); LOG(ERROR) << errorMsg; } catch (...) { errorMsg = c10::str( logPrefix(), "Unknown exception thrown when waiting for future ", futDescription); - if (log) { - data.strings["status"] = "EXCEPTION"; - data.strings["exception"] = "Unknown exception"; - } + debugLog.strings["status"] = "EXCEPTION"; + debugLog.strings["exception"] = "Unknown exception"; LOG(ERROR) << errorMsg; } } else { @@ -1335,15 +1320,9 @@ bool ProcessGroupNCCL::waitForFutureOrTimeout( " timed out after ", timeOutMilSec.count(), " ms"); - data.strings["status"] = "TIMEOUT"; + debugLog.strings["status"] = "TIMEOUT"; LOG(ERROR) << errorMsg; } - if (log) { - auto logger = c10d::C10dLogger::getLogger(); - if (logger) { - logger->log(data); - } - } if (throwException && !errorMsg.empty()) { C10_THROW_ERROR(DistBackendError, errorMsg); } @@ -1418,8 +1397,9 @@ void ProcessGroupNCCL::abort() { std::future fut = std::async(std::launch::async, [this]() { return this->abortComms(); }); + ::c10d::C10dLoggingData debugLog; waitForFutureOrTimeout( - fut, options_->timeout, "ProcessGroup abort", true, false); + fut, options_->timeout, "ProcessGroup abort", debugLog, true); LOG(INFO) << logPrefix() << "ProcessGroupNCCL aborts successfully."; // We need to wait for abort to finish before we can safely shut down @@ -1765,6 +1745,12 @@ void ProcessGroupNCCL::heartbeatMonitor() { // Store debug info to storage if no other thread does it. 
(By default to // local disk) bool dumpStackTrace = true; + ::c10d::C10dLoggingData debugLog; + debugLog.integers["pg_id"] = static_cast<int64_t>(local_id_); + debugLog.integers["rank"] = rank_; + debugLog.integers["global_rank"] = globalRank(); + debugLog.integers["world_size"] = getSize(); + debugLog.strings["flight_recorder_version"] = c10d::version_val_str; for (int i = 0; i < 2; i++) { std::future<bool> asyncDebugDump = std::async(std::launch::async, [this, dumpStackTrace]() { @@ -1776,8 +1762,8 @@ void ProcessGroupNCCL::heartbeatMonitor() { asyncDebugDump, std::chrono::milliseconds(waitTimeoutDumpInMilSec_), "Flight recorder dump in heartbeatMonitor", - false, - true); + debugLog, + false); if (complete) { LOG(INFO) @@ -1789,6 +1775,11 @@ void ProcessGroupNCCL::heartbeatMonitor() { // iteration. dumpStackTrace = false; } + debugLog.integers["trace_enabled"] = int64_t(dumpStackTrace); + auto logger = c10d::C10dLogger::getLogger(); + if (logger) { + logger->log(debugLog); + } // Indicate to watchdog thread that we have finished dumping. 
promiseFlightRecorderDump_.set_value(); } diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index c77bbb5e501..9a9f69cd940 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -24,6 +24,7 @@ #include #include #include +#include #include #include @@ -985,8 +986,8 @@ class TORCH_API ProcessGroupNCCL : public Backend { std::future& fut, const std::chrono::milliseconds& timeOutMilSec, const std::string& futDescription, - bool throwException = false, - bool log = false); + ::c10d::C10dLoggingData& debugLog, + bool throwException = false); std::string getNCCLWatchdogTimeoutErrorMsg(const std::string& extraMsg); diff --git a/torch/csrc/distributed/c10d/logger.hpp b/torch/csrc/distributed/c10d/logger.hpp index d2949a4f674..c1797046a97 100644 --- a/torch/csrc/distributed/c10d/logger.hpp +++ b/torch/csrc/distributed/c10d/logger.hpp @@ -1,3 +1,5 @@ +#pragma once + #include #include