pytorch/test/cpp/c10d/TCPStoreTest.cpp

#include <c10/util/irange.h>
#include "StoreTestCommon.hpp"

#include <cstdlib>
#include <future>
#include <iostream>
#include <system_error>
#include <thread>

#include <gtest/gtest.h>

#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
#include <torch/csrc/distributed/c10d/TCPStore.hpp>

constexpr int64_t kShortStoreTimeoutMillis = 100;
constexpr int defaultTimeout = 20; // seconds
c10::intrusive_ptr<c10d::TCPStore> _createServer(
    bool useLibUV,
    int numWorkers = 1,
    int timeout = defaultTimeout) {
  return c10::make_intrusive<c10d::TCPStore>(
      "127.0.0.1",
      c10d::TCPStoreOptions{
          /* port */ 0,
          /* isServer */ true,
          numWorkers,
          /* waitWorkers */ false,
          /* timeout */ std::chrono::seconds(timeout),
          /* multiTenant */ false,
          /* masterListenFd */ c10::nullopt,
          /* useLibUV */ useLibUV});
}
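
// A client-side counterpart to _createServer, shown as a hypothetical sketch
// (this helper is not part of the original file; the tests below build their
// clients inline the same way). It assumes only TCPStoreOptions fields that
// are exercised elsewhere in this test.
c10::intrusive_ptr<c10d::TCPStore> _createClient(
    std::uint16_t port,
    int numWorkers = 1) {
  c10d::TCPStoreOptions opts{};
  opts.port = port; // connect to the server's dynamically assigned port
  opts.isServer = false;
  opts.numWorkers = numWorkers;
  opts.waitWorkers = false;
  opts.timeout = std::chrono::seconds(defaultTimeout);
  return c10::make_intrusive<c10d::TCPStore>("127.0.0.1", opts);
}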
// Different ports for different tests.
void testHelper(bool useLibUV, const std::string& prefix = "") {
  constexpr auto numThreads = 16;
  constexpr auto numWorkers = numThreads + 1;

  auto serverTCPStore = _createServer(useLibUV, numWorkers);
  auto serverStore =
      c10::make_intrusive<c10d::PrefixStore>(prefix, serverTCPStore);

  // Server-side thread: exercises the store directly on the server.
  auto serverThread = std::thread([&serverStore, &serverTCPStore] {
    // Wait for all workers to join.
    serverTCPStore->waitForWorkers();
    // Basic set/get on the server store
    c10d::test::set(*serverStore, "key0", "value0");
    c10d::test::set(*serverStore, "key1", "value1");
    c10d::test::set(*serverStore, "key2", "value2");
    c10d::test::check(*serverStore, "key0", "value0");
    c10d::test::check(*serverStore, "key1", "value1");
    c10d::test::check(*serverStore, "key2", "value2");

    serverStore->add("counter", 1);
    auto numKeys = serverStore->getNumKeys();
    // We expect 5 keys: the 3 set above, the 'counter' key added just above,
    // and the init key the server uses to coordinate workers.
    EXPECT_EQ(numKeys, 5);

    // Check compareSet; the return value is not checked here.
    c10d::test::compareSet(
        *serverStore, "key0", "wrongExpectedValue", "newValue");
    c10d::test::check(*serverStore, "key0", "value0");
    c10d::test::compareSet(*serverStore, "key0", "value0", "newValue");
    c10d::test::check(*serverStore, "key0", "newValue");

    auto delSuccess = serverStore->deleteKey("key0");
    // Ensure that the key was successfully deleted.
    EXPECT_TRUE(delSuccess);
    auto delFailure = serverStore->deleteKey("badKeyName");
    // The key was not in the store, so the delete operation should have
    // failed and returned false.
    EXPECT_FALSE(delFailure);
    numKeys = serverStore->getNumKeys();
    EXPECT_EQ(numKeys, 4);

    auto timeout = std::chrono::milliseconds(kShortStoreTimeoutMillis);
    serverStore->setTimeout(timeout);
    // 'key0' was deleted above, so this get() should time out.
    EXPECT_THROW(serverStore->get("key0"), c10::Error);
  });
  // Hammer on TCPStore from multiple client threads.
  std::vector<std::thread> threads;
  constexpr auto numIterations = 1000;
  c10d::test::Semaphore sem1, sem2;

  c10d::TCPStoreOptions opts{};
  opts.port = serverTCPStore->getPort();
  opts.numWorkers = numWorkers;

  // Each thread will have a client store to send/recv data.
  std::vector<c10::intrusive_ptr<c10d::TCPStore>> clientTCPStores;
  std::vector<c10::intrusive_ptr<c10d::PrefixStore>> clientStores;
  for (const auto i : c10::irange(numThreads)) {
    clientTCPStores.push_back(
        c10::make_intrusive<c10d::TCPStore>("127.0.0.1", opts));
    clientStores.push_back(
        c10::make_intrusive<c10d::PrefixStore>(prefix, clientTCPStores[i]));
  }

  // numThreads clients each add 1 to 'counter' numIterations times, plus
  // the server's initial add of 1.
  std::string expectedCounterRes =
      std::to_string(numThreads * numIterations + 1);
  for (const auto i : c10::irange(numThreads)) {
    threads.emplace_back(
        std::thread([=, &sem1, &sem2, &clientStores, &expectedCounterRes] {
          for (C10_UNUSED const auto j : c10::irange(numIterations)) {
            clientStores[i]->add("counter", 1);
          }
          // Let each thread set and get a key on its client store.
          std::string key = "thread_" + std::to_string(i);
          for (const auto j : c10::irange(numIterations)) {
            std::string val = "thread_val_" + std::to_string(j);
            c10d::test::set(*clientStores[i], key, val);
            c10d::test::check(*clientStores[i], key, val);
          }

          // Rendezvous with the other client threads before reading
          // their data.
          sem1.post();
          sem2.wait();

          // Check the counter result.
          c10d::test::check(*clientStores[i], "counter", expectedCounterRes);

          // Now check the other threads' written data.
          for (const auto j : c10::irange(numThreads)) {
            if (j == i) {
              continue;
            }
            std::string key = "thread_" + std::to_string(j);
            std::string val = "thread_val_" + std::to_string(numIterations - 1);
            c10d::test::check(*clientStores[i], key, val);
          }
        }));
  }
  sem1.wait(numThreads);
  sem2.post(numThreads);

  for (auto& thread : threads) {
    thread.join();
  }
  serverThread.join();

  // Clear the client stores to verify that client disconnects do not shut
  // down the server store.
  clientStores.clear();
  clientTCPStores.clear();

  // Check that the counter has the expected value.
  c10d::test::check(*serverStore, "counter", expectedCounterRes);

  // Check each thread's written data from the main thread.
  for (const auto i : c10::irange(numThreads)) {
    std::string key = "thread_" + std::to_string(i);
    std::string val = "thread_val_" + std::to_string(numIterations - 1);
    c10d::test::check(*serverStore, key, val);
  }
}

TEST(TCPStoreTest, testHelper) {
  testHelper(false);
}

TEST(TCPStoreTest, testHelperUV) {
  testHelper(true);
}

TEST(TCPStoreTest, testHelperPrefix) {
  testHelper(false, "testPrefix");
}

TEST(TCPStoreTest, testHelperPrefixUV) {
  testHelper(true, "testPrefix");
}
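
// A minimal illustrative sketch (not part of the original suite): it mirrors
// the timeout check inside testHelper above, assuming only APIs already used
// in this file (setTimeout(), get(), and the c10::Error thrown on timeout).
TEST(TCPStoreTest, testShortTimeoutSketch) {
  auto store = _createServer(/* useLibUV */ false);
  store->setTimeout(std::chrono::milliseconds(kShortStoreTimeoutMillis));
  // No one ever sets this key, so get() should throw once the short
  // timeout elapses.
  EXPECT_THROW(store->get("never_set_key"), c10::Error);
}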

TEST(TCPStoreTest, testCleanShutdown) {
  int numWorkers = 2;

  auto serverTCPStore = std::make_unique<c10d::TCPStore>(
      "127.0.0.1",
      0,
      numWorkers,
      true,
      std::chrono::seconds(defaultTimeout),
      /* wait */ false);
  c10d::test::set(*serverTCPStore, "key", "val");

  auto clientTCPStore = c10::make_intrusive<c10d::TCPStore>(
      "127.0.0.1",
      c10d::TCPStoreOptions{
          /* port */ serverTCPStore->getPort(),
          /* isServer */ false,
          numWorkers,
          /* waitWorkers */ false,
          /* timeout */ std::chrono::seconds(defaultTimeout)});
  clientTCPStore->get("key");

  auto clientThread = std::thread([&clientTCPStore] {
    EXPECT_THROW(clientTCPStore->get("invalid_key"), c10::DistNetworkError);
  });

  // Start server shutdown during a client request.
  serverTCPStore = nullptr;

  clientThread.join();
}

void testMultiTenantStores(bool libUV) {
  c10d::TCPStoreOptions opts{};
  opts.isServer = true;
  opts.multiTenant = true;
  opts.useLibUV = libUV;

  // Construct two server stores on the same port.
  auto store1 = c10::make_intrusive<c10d::TCPStore>("localhost", opts);
  auto store2 = c10::make_intrusive<c10d::TCPStore>("localhost", opts);

  // Assert that the two stores share the same server.
  c10d::test::set(*store1, "key0", "value0");
  c10d::test::check(*store2, "key0", "value0");

  // Dispose of the second instance and assert that the server is still
  // alive.
  store2.reset();
  c10d::test::set(*store1, "key0", "value0");
  c10d::test::check(*store1, "key0", "value0");
}

TEST(TCPStoreTest, testMultiTenantStores) {
  testMultiTenantStores(false);
}

TEST(TCPStoreTest, testMultiTenantStoresUV) {
  testMultiTenantStores(true);
}