mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/36893 Adding an end to end test for running a simple training loop in C++ for the distributed RPC framework. The goal of this change is to enable LeakSanitizer and potentially catch memory leaks in the Future. Enabling LSAN with python multiprocessing is tricky and we haven't found a solution for this. As a result, adding a C++ test that triggers most of the critical codepaths would be good for now. As an example, this unit test would've caught the memory leak fixed by: https://github.com/pytorch/pytorch/pull/31030 ghstack-source-id: 107781167 Test Plan: 1) Verify the test catches memory leaks. 2) waitforbuildbot Reviewed By: mrshenli Differential Revision: D21112208 fbshipit-source-id: 4eb2a6b409253108f6b6e14352e593d250c7a64d
46 lines
1.3 KiB
C++
46 lines
1.3 KiB
C++
#include <gtest/gtest.h>
|
|
|
|
#include "e2e_test_base.h"
|
|
|
|
#include <c10d/ProcessGroupGloo.hpp>
|
|
#include <torch/csrc/distributed/rpc/process_group_agent.h>
|
|
#include <torch/csrc/distributed/rpc/request_callback_no_python.h>
|
|
#include <torch/torch.h>
|
|
|
|
namespace torch {
|
|
namespace distributed {
|
|
namespace rpc {
|
|
|
|
using namespace torch::distributed::autograd;
|
|
|
|
class TestE2EProcessGroup : public TestE2EBase {
|
|
protected:
|
|
void buildRpcAgent() override {
|
|
c10d::ProcessGroupGloo::Options options;
|
|
options.devices.push_back(
|
|
::c10d::ProcessGroupGloo::createDeviceForHostname(serverAddress));
|
|
std::chrono::milliseconds rpcTimeout(30000);
|
|
|
|
// Initialize server rpc agent.
|
|
auto pg =
|
|
std::make_shared<c10d::ProcessGroupGloo>(store, 0, numWorkers, options);
|
|
|
|
rpcAgent = std::make_shared<ProcessGroupAgent>(
|
|
"worker",
|
|
pg,
|
|
std::max(16U, std::thread::hardware_concurrency()),
|
|
rpcTimeout,
|
|
std::make_unique<RequestCallbackNoPython>());
|
|
}
|
|
};
|
|
|
|
// End to end training loop test in C++ so that we can run LSAN on this test to
|
|
// catch memory leaks. Enabling LSAN with python multiprocessing has been
|
|
// challenging and we don't have a good solution yet.
|
|
TEST_F(TestE2EProcessGroup, TestTrainingLoop) {
|
|
runTrainingLoop();
|
|
}
|
|
|
|
} // namespace rpc
|
|
} // namespace distributed
|
|
} // namespace torch
|