mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[c10d][fr] flight recorder improvements (#143446)
Summary: 1. Flight recorder dumps are now automatically dumped by default upon timeout or exception. Users don't need to opt-in. 2. Change default dump location to running user's home directory `.cache` folder. Test Plan: 1. Tested locally by running the crash program from flight recorder tutorial page. https://pytorch.org/tutorials/prototype/flight_recorder_tutorial.html#an-end-to-end-example 2. Noted that flight recorder files were correctly created. ❯ pwd /home/cpio/.cache/fr_trace ❯ ls nccl_trace_rank_0 nccl_trace_rank_1 Differential Revision: [D67363720](https://our.internmc.facebook.com/intern/diff/D67363720) Pull Request resolved: https://github.com/pytorch/pytorch/pull/143446 Approved by: https://github.com/d4l3k
This commit is contained in:
parent
a94f259a69
commit
485497e727
3 changed files with 56 additions and 2 deletions
|
|
@ -2901,6 +2901,8 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
|
|||
)
|
||||
# avoid watchdog thread interference
|
||||
os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "0"
|
||||
# set heartbeat timeout to a small value so that we don't wait too long for things to shutdown
|
||||
os.environ["TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC"] = "5"
|
||||
store = c10d.FileStore(self.file_name, self.world_size)
|
||||
process_group = c10d.ProcessGroupNCCL(
|
||||
store,
|
||||
|
|
|
|||
|
|
@ -3,6 +3,11 @@
|
|||
|
||||
#include <cuda_runtime.h>
|
||||
#include <nlohmann/json.hpp>
|
||||
#ifndef _WIN32
|
||||
#include <sys/stat.h>
|
||||
#else
|
||||
#include <direct.h>
|
||||
#endif
|
||||
#include <fstream>
|
||||
#include <mutex>
|
||||
#include <vector>
|
||||
|
|
@ -108,6 +113,45 @@ control_plane::RegisterHandler jsonDumpHandler{
|
|||
"application/json");
|
||||
}};
|
||||
|
||||
bool recursive_mkdir(const std::string& dir) {
|
||||
// Check if current dir exists
|
||||
const char* p_dir = dir.c_str();
|
||||
const bool dir_exists = (access(p_dir, F_OK) == 0);
|
||||
if (dir_exists) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Find folder separator and check if we are at the top
|
||||
auto pos = dir.find_last_of("/\\");
|
||||
if (pos == std::string::npos) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Try to create parent directory
|
||||
if (!(recursive_mkdir(dir.substr(0, pos)))) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Try to create current directory
|
||||
#ifdef _WIN32
|
||||
int ret = _mkdir(dir.c_str());
|
||||
#else
|
||||
int ret = mkdir(dir.c_str(), S_IRWXU | S_IRWXG);
|
||||
#endif
|
||||
// Success
|
||||
if (ret == 0) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Try to create complete path again
|
||||
#ifdef _WIN32
|
||||
ret = _mkdir(dir.c_str());
|
||||
#else
|
||||
ret = mkdir(dir.c_str(), S_IRWXU | S_IRWXG);
|
||||
#endif
|
||||
return ret == 0;
|
||||
}
|
||||
|
||||
void DebugInfoWriter::write(const std::string& trace) {
|
||||
// Open a file for writing. The ios::binary flag is used to write data as
|
||||
// binary.
|
||||
|
|
@ -131,8 +175,16 @@ void DebugInfoWriter::write(const std::string& trace) {
|
|||
|
||||
DebugInfoWriter& DebugInfoWriter::getWriter(int rank) {
|
||||
if (writer_ == nullptr) {
|
||||
// Attempt to write to running user's HOME directory cache folder - if it
|
||||
// exists.
|
||||
auto homeDir = getCvarString({"HOME"}, "/tmp");
|
||||
std::string cacheDirPath = homeDir + "/.cache/torch";
|
||||
// Create the .cache directory if it doesn't exist
|
||||
recursive_mkdir(cacheDirPath);
|
||||
std::string defaultLocation = cacheDirPath + "/" + "nccl_trace_rank_";
|
||||
|
||||
std::string fileNamePrefix = getCvarString(
|
||||
{"TORCH_NCCL_DEBUG_INFO_TEMP_FILE"}, "/tmp/nccl_trace_rank_");
|
||||
{"TORCH_NCCL_DEBUG_INFO_TEMP_FILE"}, defaultLocation.c_str());
|
||||
// Using std::unique_ptr here to auto-delete the writer object
|
||||
// when the pointer itself is destroyed.
|
||||
std::unique_ptr<DebugInfoWriter> writerPtr(
|
||||
|
|
|
|||
|
|
@ -896,7 +896,7 @@ ProcessGroupNCCL::ProcessGroupNCCL(
|
|||
// TODO, we should either deprecate TORCH_NCCL_DUMP_ON_TIMEOUT
|
||||
// or change its name to reflect that dump happens on exception including
|
||||
// both timeout and other errors.
|
||||
dumpOnTimeoutOrEx_ = getCvarBool(TORCH_NCCL_DUMP_ON_TIMEOUT, false) ||
|
||||
dumpOnTimeoutOrEx_ = getCvarBool(TORCH_NCCL_DUMP_ON_TIMEOUT, true) ||
|
||||
(dist_debug_level_ >= DebugLevel::Detail);
|
||||
// logging C++ stack isn't safe. Introduce a variable to control it.
|
||||
logCppStackOnUncleanShutdown_ =
|
||||
|
|
|
|||
Loading…
Reference in a new issue