mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[Distributed] [10/N] Fix clang-tidy warnings in torch/csrc/distributed/c10d/control_plane (#131671)
Follows #130109 Pull Request resolved: https://github.com/pytorch/pytorch/pull/131671 Approved by: https://github.com/zou3519
This commit is contained in:
parent
2d7c135757
commit
62704db5c3
4 changed files with 15 additions and 24 deletions
|
|
@ -4,9 +4,9 @@
|
|||
#include <mutex>
|
||||
#include <shared_mutex>
|
||||
#include <stdexcept>
|
||||
#include <utility>
|
||||
|
||||
namespace c10d {
|
||||
namespace control_plane {
|
||||
namespace c10d::control_plane {
|
||||
|
||||
namespace {
|
||||
|
||||
|
|
@ -20,7 +20,7 @@ class HandlerRegistry {
|
|||
fmt::format("Handler {} already registered", name));
|
||||
}
|
||||
|
||||
handlers_[name] = f;
|
||||
handlers_[name] = std::move(f);
|
||||
}
|
||||
|
||||
HandlerFunc getHandler(const std::string& name) {
|
||||
|
|
@ -38,6 +38,7 @@ class HandlerRegistry {
|
|||
std::shared_lock<std::shared_mutex> lock(handlersMutex_);
|
||||
|
||||
std::vector<std::string> names;
|
||||
names.reserve(handlers_.size());
|
||||
for (const auto& [name, _] : handlers_) {
|
||||
names.push_back(name);
|
||||
}
|
||||
|
|
@ -62,7 +63,7 @@ RegisterHandler pingHandler{"ping", [](const Request&, Response& res) {
|
|||
} // namespace
|
||||
|
||||
void registerHandler(const std::string& name, HandlerFunc f) {
|
||||
return getHandlerRegistry().registerHandler(name, f);
|
||||
return getHandlerRegistry().registerHandler(name, std::move(f));
|
||||
}
|
||||
|
||||
HandlerFunc getHandler(const std::string& name) {
|
||||
|
|
@ -73,5 +74,4 @@ std::vector<std::string> getHandlerNames() {
|
|||
return getHandlerRegistry().getHandlerNames();
|
||||
}
|
||||
|
||||
} // namespace control_plane
|
||||
} // namespace c10d
|
||||
} // namespace c10d::control_plane
|
||||
|
|
|
|||
|
|
@ -3,11 +3,11 @@
|
|||
#include <functional>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include <c10/macros/Export.h>
|
||||
|
||||
namespace c10d {
|
||||
namespace control_plane {
|
||||
namespace c10d::control_plane {
|
||||
|
||||
// Request represents a request to the handler. This conceptually maps to an
|
||||
// HTTP request but could be called via other transports.
|
||||
|
|
@ -56,7 +56,7 @@ TORCH_API std::vector<std::string> getHandlerNames();
|
|||
class TORCH_API RegisterHandler {
|
||||
public:
|
||||
RegisterHandler(const std::string& name, HandlerFunc f) {
|
||||
registerHandler(name, f);
|
||||
registerHandler(name, std::move(f));
|
||||
}
|
||||
|
||||
// disable move, copy
|
||||
|
|
@ -66,5 +66,4 @@ class TORCH_API RegisterHandler {
|
|||
RegisterHandler& operator=(RegisterHandler&&) = delete;
|
||||
};
|
||||
|
||||
} // namespace control_plane
|
||||
} // namespace c10d
|
||||
} // namespace c10d::control_plane
|
||||
|
|
|
|||
|
|
@ -1,8 +1,5 @@
|
|||
#include <filesystem>
|
||||
#include <mutex>
|
||||
#include <shared_mutex>
|
||||
#include <sstream>
|
||||
#include <tuple>
|
||||
#include <unordered_map>
|
||||
|
||||
#include <ATen/core/interned_strings.h>
|
||||
|
|
@ -11,8 +8,7 @@
|
|||
#include <torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp>
|
||||
#include <torch/csrc/distributed/c10d/logging.h>
|
||||
|
||||
namespace c10d {
|
||||
namespace control_plane {
|
||||
namespace c10d::control_plane {
|
||||
|
||||
namespace {
|
||||
class RequestImpl : public Request {
|
||||
|
|
@ -192,5 +188,4 @@ WorkerServer::~WorkerServer() {
|
|||
}
|
||||
}
|
||||
|
||||
} // namespace control_plane
|
||||
} // namespace c10d
|
||||
} // namespace c10d::control_plane
|
||||
|
|
|
|||
|
|
@ -2,20 +2,18 @@
|
|||
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <unordered_map>
|
||||
|
||||
#include <httplib.h>
|
||||
|
||||
#include <c10/util/intrusive_ptr.h>
|
||||
#include <torch/csrc/distributed/c10d/control_plane/Handlers.hpp>
|
||||
|
||||
namespace c10d {
|
||||
namespace control_plane {
|
||||
namespace c10d::control_plane {
|
||||
|
||||
class TORCH_API WorkerServer : public c10::intrusive_ptr_target {
|
||||
public:
|
||||
WorkerServer(const std::string& hostOrFile, int port = -1);
|
||||
~WorkerServer();
|
||||
~WorkerServer() override;
|
||||
|
||||
void shutdown();
|
||||
|
||||
|
|
@ -24,5 +22,4 @@ class TORCH_API WorkerServer : public c10::intrusive_ptr_target {
|
|||
std::thread serverThread_;
|
||||
};
|
||||
|
||||
} // namespace control_plane
|
||||
} // namespace c10d
|
||||
} // namespace c10d::control_plane
|
||||
|
|
|
|||
Loading…
Reference in a new issue