pytorch/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp
Hongyi Jia 146a7f68e2 Enable desync root cause analysis for NCCL (#68310)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68310

Enable desync root-cause analysis by recording the last footprint of each collective call. On timeout, we parse the store trace to determine the root cause of the desync. This feature builds on async error handling.

Test Plan:
Standalone test
* Typical desync - P467288969
* Mismatched collectives - P467288916
* Mismatched broadcast size - P467288873

DDP benchmark
* DDP benchmark desync - P467433483, P467520195

No perf regression:
* w/o this diff https://www.internalfb.com/intern/fblearner/details/308379789?tab=Outputs
* w/ this diff https://www.internalfb.com/intern/fblearner/details/308534088?tab=Outputs

Reviewed By: mingzhe09088

Differential Revision: D32348647

fbshipit-source-id: 43e7e96e3fa2be0ac66c1325bceb639b461a8b3a
2021-11-17 20:29:03 -08:00

258 lines
7 KiB
C++

#include <atomic>
#include <chrono>
#include <c10/util/irange.h>
#include <c10d/FileStore.hpp>
#include <c10d/ProcessGroupNCCL.hpp>
#include "CUDATest.hpp"
#include "TestUtils.hpp"
#include <torch/csrc/cuda/nccl.h>
#include <gtest/gtest.h>
using namespace c10d::test;
// Minimum value of torch::cuda::nccl::version() for which these tests run
// (checked in the fixture's skipTest()). NOTE(review): presumably NCCL 2.4.0
// under NCCL's major*1000 + minor*100 + patch version encoding -- confirm.
constexpr int kNcclErrorHandlingVersion{2400};
// Work item whose NCCL error check can be forced to fail, letting tests drive
// the error-handling path without a real NCCL failure.
class WorkNCCLSimulateErrors : public c10d::ProcessGroupNCCL::WorkNCCL {
 public:
  // simulate_error: when true, checkForNCCLErrors() always reports a
  // synthetic error instead of consulting the real communicators.
  WorkNCCLSimulateErrors(
      const std::vector<at::Device>& devices,
      bool simulate_error,
      int rank,
      c10d::OpType opType,
      uint64_t seq)
      : WorkNCCL(devices, rank, opType, seq), simulate_error_(simulate_error) {}

  // Defers to the real WorkNCCL check unless simulation is enabled, in which
  // case a std::runtime_error("Error") is reported.
  std::exception_ptr checkForNCCLErrors(
      const std::vector<std::shared_ptr<c10d::NCCLComm>>& ncclComms)
      const override {
    if (!simulate_error_) {
      return c10d::ProcessGroupNCCL::WorkNCCL::checkForNCCLErrors(ncclComms);
    }
    return std::make_exception_ptr(std::runtime_error("Error"));
  }

 private:
  // Fixed for the lifetime of the work item (set once in the constructor).
  bool simulate_error_;
};
// Process group whose NCCL error checks can be toggled at runtime. While
// simulation is enabled, both the process-group-level check and every
// WorkNCCLSimulateErrors created by initWork() report a synthetic error.
class ProcessGroupNCCLSimulateErrors : public c10d::ProcessGroupNCCL {
 public:
  ProcessGroupNCCLSimulateErrors(
      const c10::intrusive_ptr<c10d::Store>& store,
      int rank,
      int size,
      c10::intrusive_ptr<c10d::ProcessGroupNCCL::Options> opts)
      : ProcessGroupNCCL(store, rank, size, opts), simulate_error_(false) {}

  // Process-group-level error check. NOTE(review): ProcessGroupNCCL appears to
  // call this from its background watchdog thread while the test thread
  // toggles the flag, which is why simulate_error_ is atomic -- confirm
  // against ProcessGroupNCCL.cpp.
  std::exception_ptr checkForNCCLErrors(
      const std::vector<std::shared_ptr<c10d::NCCLComm>>& ncclComms) override {
    if (simulate_error_.load()) {
      return std::make_exception_ptr(std::runtime_error("Error"));
    }
    return c10d::ProcessGroupNCCL::checkForNCCLErrors(ncclComms);
  }

  // Returns the base class's kWatchdogThreadSleepMillis as a chrono duration.
  std::chrono::duration<int64_t, std::milli> getWatchdogSleepInterval() {
    return std::chrono::milliseconds(
        ProcessGroupNCCLSimulateErrors::kWatchdogThreadSleepMillis);
  }

  // Creates work items carrying the simulation flag as it is right now. The
  // flag is copied into the work item, so toggling it later does not affect
  // already-created work. profilingTitle and inputs are unused.
  c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> initWork(
      std::vector<at::Device> devices,
      int rank,
      c10d::OpType opType,
      const char* profilingTitle,
      const c10::optional<std::vector<at::Tensor>>& inputs = c10::nullopt) override {
    return c10::make_intrusive<WorkNCCLSimulateErrors>(
        devices, simulate_error_.load(), rank, opType, seq_);
  }

  // Number of entries in the base class's devNCCLCommMap_ (communicator cache).
  // NOTE(review): reads the map without taking the base class's lock; only
  // call when no collective or background thread may be mutating it.
  size_t getNCCLCommCacheSize() {
    return devNCCLCommMap_.size();
  }

  // Makes subsequent error checks (PG-level and newly created work) fail.
  void simulate_error() {
    simulate_error_.store(true);
  }

  // Restores normal (real) error checking.
  void reset_error() {
    simulate_error_.store(false);
  }

 private:
  // Atomic because it is written by the test thread and read by the
  // checkForNCCLErrors override, which may run on another thread.
  std::atomic<bool> simulate_error_;
};
class WorkNCCLTimedoutErrors : public c10d::ProcessGroupNCCL::WorkNCCL {
public:
WorkNCCLTimedoutErrors(
const std::vector<at::Device>& devices,
bool set_timedout_error,
int rank,
c10d::OpType opType,
uint64_t seq)
: WorkNCCL(devices, rank, opType, seq),
set_timedout_error_(set_timedout_error) {}
private:
bool isCompleted() override {
if (set_timedout_error_) {
return false;
}
return c10d::ProcessGroupNCCL::WorkNCCL::isCompleted();
}
private:
bool set_timedout_error_;
};
// Process group whose work items can be forced to never complete, used to
// exercise timeout handling. It replaces the parent's initWork(), so its work
// items are WorkNCCLTimedoutErrors and do not carry the parent's
// simulate-error flag.
class ProcessGroupNCCLTimedOutErrors : public ProcessGroupNCCLSimulateErrors {
 public:
  ProcessGroupNCCLTimedOutErrors(
      const c10::intrusive_ptr<c10d::Store>& store,
      int rank,
      int size,
      c10::intrusive_ptr<c10d::ProcessGroupNCCL::Options> opts)
      : ProcessGroupNCCLSimulateErrors(store, rank, size, opts) {}

  // Hands out work items carrying the current timeout flag; profilingTitle
  // and inputs are unused.
  c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> initWork(
      std::vector<at::Device> devices,
      int rank,
      c10d::OpType opType,
      const char* profilingTitle,
      const c10::optional<std::vector<at::Tensor>>& inputs = c10::nullopt) override {
    const bool neverComplete = set_timedout_error_;
    return c10::make_intrusive<WorkNCCLTimedoutErrors>(
        devices, neverComplete, rank, opType, seq_);
  }

  // Work created from now on never reports completion.
  void set_timedout_error() { set_timedout_error_ = true; }

  // Work created from now on completes normally again.
  void reset_timedout_error() { set_timedout_error_ = false; }

 private:
  bool set_timedout_error_{false};
};
class ProcessGroupNCCLErrorsTest : public ::testing::Test {
protected:
bool skipTest() {
if (cudaNumDevices() == 0) {
LOG(INFO) << "Skipping test since CUDA is not available";
return true;
}
#ifdef USE_C10D_NCCL
if (torch::cuda::nccl::version() < kNcclErrorHandlingVersion) {
LOG(INFO) << "Skipping test since NCCL version is too old";
return true;
}
#endif
return false;
}
void SetUp() override {
size_t numDevices = cudaNumDevices();
TemporaryFile file;
store_ = c10::make_intrusive<::c10d::FileStore>(file.path, 1);
at::cuda::OptionalCUDAGuard deviceGuard;
tensors_.resize(numDevices);
for (const auto i : c10::irange(numDevices)) {
deviceGuard.set_index(i);
tensors_[i] = at::ones({3, 3}, at::kCUDA);
}
}
void TearDown() override {
ASSERT_TRUE(setenv(c10d::NCCL_BLOCKING_WAIT, "0", 1) == 0);
}
std::vector<at::Tensor> tensors_;
c10::intrusive_ptr<::c10d::FileStore> store_;
};
// With NCCL_BLOCKING_WAIT=1, a simulated NCCL error must surface as an
// exception from wait(), and the failed work item must stay failed.
TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsBlocking) {
  if (skipTest()) {
    return;
  }

  ASSERT_TRUE(setenv(c10d::NCCL_BLOCKING_WAIT, "1", 1) == 0);
  auto options = c10d::ProcessGroupNCCL::Options::create();
  options->timeout = std::chrono::milliseconds(1000);
  ProcessGroupNCCLSimulateErrors pg(store_, 0, 1, options);

  // A healthy allreduce succeeds and leaves one communicator set cached.
  auto work = pg.allreduce(tensors_);
  work->wait();
  EXPECT_TRUE(work->isSuccess());
  // size_t literal: avoids a signed/unsigned comparison inside EXPECT_EQ.
  EXPECT_EQ(size_t{1}, pg.getNCCLCommCacheSize());

  // Now run all reduce with errors.
  pg.simulate_error();
  work = pg.allreduce(tensors_);
  EXPECT_THROW(work->wait(), std::runtime_error);

  // Verify the work item failed, and that waiting on it again rethrows.
  EXPECT_TRUE(work->isCompleted());
  EXPECT_FALSE(work->isSuccess());
  EXPECT_THROW(work->wait(), std::runtime_error);
  // Communicators might be aborted here, further operations would fail.
}
// With NCCL_BLOCKING_WAIT=1, a work item that never reports completion is
// expected to make wait() throw c10::Error (presumably once options->timeout
// elapses).
TEST_F(ProcessGroupNCCLErrorsTest, testNCCLTimedoutErrorsBlocking) {
  if (skipTest()) {
    return;
  }

  ASSERT_TRUE(setenv(c10d::NCCL_BLOCKING_WAIT, "1", 1) == 0);
  auto options = c10d::ProcessGroupNCCL::Options::create();
  options->timeout = std::chrono::milliseconds(3000);
  ProcessGroupNCCLTimedOutErrors pg(store_, 0, 1, options);

  // A healthy allreduce succeeds and leaves one communicator set cached.
  auto work = pg.allreduce(tensors_);
  work->wait();
  EXPECT_TRUE(work->isSuccess());
  // size_t literal: avoids a signed/unsigned comparison inside EXPECT_EQ.
  EXPECT_EQ(size_t{1}, pg.getNCCLCommCacheSize());

  // Now run all reduce with errors: the work item never reports completion.
  pg.set_timedout_error();
  work = pg.allreduce(tensors_);
  EXPECT_THROW(work->wait(), c10::Error);
  // Communicators might be aborted here, further operations would fail.
}
// Without blocking wait, a simulated NCCL error must not throw from wait();
// the failure is observed through isCompleted()/isSuccess() instead.
TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNonBlocking) {
  if (skipTest()) {
    return;
  }

  // Pin non-blocking mode explicitly rather than relying on the environment
  // or on another test's TearDown having reset it.
  ASSERT_TRUE(setenv(c10d::NCCL_BLOCKING_WAIT, "0", 1) == 0);
  auto options = c10d::ProcessGroupNCCL::Options::create();
  options->timeout = std::chrono::milliseconds(3000);
  ProcessGroupNCCLSimulateErrors pg(store_, 0, 1, options);

  // A healthy allreduce succeeds and leaves one communicator set cached.
  auto work = pg.allreduce(tensors_);
  pg.barrier()->wait();
  EXPECT_TRUE(work->isSuccess());
  // size_t literal: avoids a signed/unsigned comparison inside EXPECT_EQ.
  EXPECT_EQ(size_t{1}, pg.getNCCLCommCacheSize());

  // Now run all reduce with errors.
  pg.simulate_error();
  work = pg.allreduce(tensors_);
  // Should not throw exceptions.
  work->wait();
  pg.barrier()->wait();

  // Verify the work item failed.
  EXPECT_TRUE(work->isCompleted());
  EXPECT_FALSE(work->isSuccess());
  // Communicators might be aborted here, further operations would fail.
}