From 3e81c21024923bd042a1436e458dc1758e380826 Mon Sep 17 00:00:00 2001 From: wuyangyu Date: Thu, 15 Aug 2024 11:08:32 +0800 Subject: [PATCH 01/14] proxy1.11 --- torch_npu/csrc/distributed/LocalClient.cpp | 136 ++++++++++++++ torch_npu/csrc/distributed/LocalClient.hpp | 23 +++ .../csrc/distributed/ParallelTcpStore.cpp | 33 ++-- torch_npu/csrc/distributed/local_server.cpp | 167 ++++++++++++++++++ torch_npu/csrc/distributed/local_server.hpp | 62 +++++++ torch_npu/csrc/distributed/proxy.cpp | 47 +++++ torch_npu/csrc/distributed/proxy.hpp | 25 +++ 7 files changed, 476 insertions(+), 17 deletions(-) create mode 100644 torch_npu/csrc/distributed/LocalClient.cpp create mode 100644 torch_npu/csrc/distributed/LocalClient.hpp create mode 100644 torch_npu/csrc/distributed/local_server.cpp create mode 100644 torch_npu/csrc/distributed/local_server.hpp create mode 100644 torch_npu/csrc/distributed/proxy.cpp create mode 100644 torch_npu/csrc/distributed/proxy.hpp diff --git a/torch_npu/csrc/distributed/LocalClient.cpp b/torch_npu/csrc/distributed/LocalClient.cpp new file mode 100644 index 0000000000..1560f0d12d --- /dev/null +++ b/torch_npu/csrc/distributed/LocalClient.cpp @@ -0,0 +1,136 @@ +#include +#include +#include + +#include +#include +#include +#include + +#include "c10/util/Logging.h" +#include "LocalClient.hpp" + +namespace c10d { +namespace pta { +static constexpr uint32_t READ_BUF_SZ = 256; + +LocalClient::LocalClient(std::string socketPath) noexcept + : socketPath_{ std::move(socketPath) }, socketFd_{ -1 } +{} + +int LocalClient::Connect() noexcept +{ + socketFd_ = socket(AF_UNIX, SOCK_STREAM, 0); + if (socketFd_ < 0) { + LOG(ERROR) << "create local client socket failed " << errno << " : " << strerror(errno); + return -1; + } + + struct sockaddr_un servAddr {}; + servAddr.sun_family = AF_UNIX; + strncpy(servAddr.sun_path, socketPath_.c_str(), sizeof(servAddr.sun_path) - 1); + + int lastError = 0; + auto endTime = std::chrono::steady_clock::now() + 
std::chrono::minutes(1); + while (std::chrono::steady_clock::now() < endTime) { + auto ret = connect(socketFd_, reinterpret_cast(&servAddr), sizeof(servAddr)); + if (ret == 0) { + return 0; + } + + if (errno != lastError) { + LOG(ERROR) << "connect socket to local server(" << socketPath_ << ") failed " << errno << " : " << + strerror(errno); + lastError = errno; + } + + if (errno == ENOENT || errno == ECONNREFUSED) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + continue; + } + + break; + } + + return -1; +} + +int LocalClient::Close() noexcept +{ + auto ret = close(socketFd_); + if (ret == 0) { + socketFd_ = -1; + return 0; + } + + LOG(ERROR) << "close socket to local server(" << socketPath_ << ") failed " << errno << " : " << + strerror(errno); + return ret; +} + +int LocalClient::SyncCall(const StoreMessage &request, StoreMessage &response) noexcept +{ + auto packedRequest = StoreMessagePacker::Pack(request); + auto ret = write(socketFd_, packedRequest.data(), packedRequest.size()); + if (ret < 0) { + LOG(ERROR) << "write data to local server(" << socketPath_ << ") failed " << errno << " : " << + strerror(errno); + return -1; + } + + uint8_t buffer[READ_BUF_SZ]; + std::vector responseBuf; + + bool finished = false; + int result = -1; + while (!finished) { + do { + ret = read(socketFd_, buffer, READ_BUF_SZ); + if (ret < 0) { + LOG(ERROR) << "read data from local server(" << socketPath_ << ") failed " << errno << " : " << + strerror(errno); + return -1; + } + + responseBuf.insert(responseBuf.end(), buffer, buffer + ret); + } while (!StoreMessagePacker::Full(responseBuf)); + + auto unpackRet = StoreMessagePacker::Unpack(responseBuf, response); + if (unpackRet < 0L) { + LOG(ERROR) << "unpack response data from local server(" << socketPath_ << ") failed " << unpackRet; + finished = true; + result = -1; + continue; + } + + if (response.mt == request.mt) { + finished = true; + result = 0; + continue; + } + + responseBuf.erase(responseBuf.begin(), 
responseBuf.begin() + unpackRet); + } + + return result; +} + +int LocalClient::SetReceiveTimeout(const std::chrono::milliseconds &value) const noexcept +{ + if (value == std::chrono::milliseconds::zero()) { + return 0; + } + struct timeval timeoutTV = { + .tv_sec = static_cast(value.count() / 1000), + .tv_usec = static_cast((value.count() % 1000) * 1000) + }; + + auto ret = setsockopt(socketFd_, SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast(&timeoutTV), sizeof(timeoutTV)); + if (ret != 0) { + LOG(ERROR) << "set local connection receive timeout failed: " << errno << " : " << strerror(errno); + } + + return ret; +} +} // pta +} // c10d \ No newline at end of file diff --git a/torch_npu/csrc/distributed/LocalClient.hpp b/torch_npu/csrc/distributed/LocalClient.hpp new file mode 100644 index 0000000000..4464ebcb37 --- /dev/null +++ b/torch_npu/csrc/distributed/LocalClient.hpp @@ -0,0 +1,23 @@ +#pragma once + +#include +#include + +#include "StoreMessagePacker.hpp" + +namespace c10d { +namespace pta { +class LocalClient { +public: + explicit LocalClient(std::string socketPath) noexcept; + int Connect() noexcept; + int Close() noexcept; + int SyncCall(const StoreMessage &request, StoreMessage &response) noexcept; + int SetReceiveTimeout(const std::chrono::milliseconds &value) const noexcept; + +private: + const std::string socketPath_; + int socketFd_; +}; +} // pta +} // c10d \ No newline at end of file diff --git a/torch_npu/csrc/distributed/ParallelTcpStore.cpp b/torch_npu/csrc/distributed/ParallelTcpStore.cpp index d2acf0df9b..1b3943e1e9 100644 --- a/torch_npu/csrc/distributed/ParallelTcpStore.cpp +++ b/torch_npu/csrc/distributed/ParallelTcpStore.cpp @@ -252,31 +252,30 @@ bool ParallelStoreServer::CheckAllKeysExistInLock(const std::vector std::mutex ParallelTcpStore::cacheServerMutex_; std::unordered_map> ParallelTcpStore::cachedServers_; -ParallelTcpStore::ParallelTcpStore(const std::string &host, const c10d::TCPStoreOptions &opts) - : Store(opts.timeout), client_{ 
host, opts.port } +ParallelTcpStore::ParallelTcpStore(const std::string& host, const TCPStoreOptions& opts) + : Store(opts.timeout) { - if (opts.isServer) { + if (opts.is_server) { if (opts.multiTenant) { server_ = GetSharedServer(initKey_, opts.port, opts.numWorkers); } else { server_ = std::make_shared(initKey_, opts.port, opts.numWorkers); } + } + // 检查环境变量 local_rank + char* local_rank_env = std::getenv("LOCAL_RANK"); + + if (local_rank_env == nullptr) { + // 如果 LOCAL_RANK 环境变量不存在,则为 Proxy + proxy_ = std::make_unique("/tmp/torch_dist_store", host, opts.port); + proxy_->Start(); + } else { + // 如果 LOCAL_RANK 环境变量存在,则为 Worker + localClient_ = std::make_unique("/tmp/torch_dist_store"); + localClient_->Connect(); } - if (client_.Connect() != 0) { - throw std::runtime_error{ std::string("connect tcp client to server(") - .append(host) - .append(":") - .append(std::to_string(opts.port)) - .append(" failed.") }; - } - - if (opts.waitWorkers) { - IncreaseKey(initKey_, 1); - if (opts.isServer) { - server_->WaitWorkers(timeout_); - } - } + // ... 其他初始化逻辑 ... 
} ParallelTcpStore::~ParallelTcpStore() noexcept diff --git a/torch_npu/csrc/distributed/local_server.cpp b/torch_npu/csrc/distributed/local_server.cpp new file mode 100644 index 0000000000..e267e3ade0 --- /dev/null +++ b/torch_npu/csrc/distributed/local_server.cpp @@ -0,0 +1,167 @@ +// LocalServer.cpp +#include "LocalServer.hpp" +#include +#include +#include +#include +#include +#include "c10/util/Logging.h" + +namespace c10d { +namespace pta { + +LocalServer::LocalServer(std::string socketPath, c10::optional numWorkers) + : numWorkers_(numWorkers), socketPath_(std::move(socketPath)), serverSocket_(-1), running_(false) +{ + InitializeHandlers(); +} + +LocalServer::~LocalServer() noexcept +{ + Stop(); +} + +void LocalServer::Start() +{ + serverSocket_ = socket(AF_UNIX, SOCK_STREAM, 0); + if (serverSocket_ < 0) { + throw std::runtime_error("Failed to create Unix domain socket"); + } + + struct sockaddr_un addr; + memset(&addr, 0, sizeof(addr)); + addr.sun_family = AF_UNIX; + strncpy(addr.sun_path, socketPath_.c_str(), sizeof(addr.sun_path) - 1); + + unlink(socketPath_.c_str()); // Remove any existing socket file + + if (bind(serverSocket_, (struct sockaddr*)&addr, sizeof(addr)) < 0) { + close(serverSocket_); + throw std::runtime_error("Failed to bind Unix domain socket"); + } + + if (listen(serverSocket_, 5) < 0) { + close(serverSocket_); + throw std::runtime_error("Failed to listen on Unix domain socket"); + } + + running_ = true; + serverThread_ = std::thread(&LocalServer::RunServer, this); +} + +void LocalServer::Stop() +{ + running_ = false; + if (serverThread_.joinable()) { + serverThread_.join(); + } + if (serverSocket_ >= 0) { + close(serverSocket_); + serverSocket_ = -1; + } + unlink(socketPath_.c_str()); +} + +void LocalServer::WaitWorkers(const std::chrono::milliseconds &timeout) noexcept +{ + if (numWorkers_ == c10::nullopt) { + return; + } + + const auto start = std::chrono::steady_clock::now(); + while (!workersReady_) { + std::unique_lock lockGuard{ 
initWaitMutex_ }; + if (timeout == Store::kNoTimeout) { + initWaitCond_.wait(lockGuard); + } else { + initWaitCond_.wait_until(lockGuard, start + timeout); + } + } +} + +void LocalServer::RunServer() +{ + while (running_) { + int clientSocket = accept(serverSocket_, nullptr, nullptr); + if (clientSocket < 0) { + if (errno != EINTR) { + LOG(ERROR) << "Failed to accept client connection: " << strerror(errno); + } + continue; + } + + std::thread clientThread(&LocalServer::HandleClient, this, clientSocket); + clientThread.detach(); + } +} + +void LocalServer::HandleClient(int clientSocket) +{ + std::vector buffer(1024); + while (running_) { + ssize_t bytesRead = recv(clientSocket, buffer.data(), buffer.size(), 0); + if (bytesRead <= 0) { + break; + } + + StoreMessage request; + auto unpackResult = StoreMessagePacker::Unpack(buffer, request); + if (unpackResult <= 0) { + LOG(ERROR) << "Failed to unpack client request"; + continue; + } + + StoreMessage response = ProcessRequest(clientSocket, request); + auto packedResponse = StoreMessagePacker::Pack(response); + send(clientSocket, packedResponse.data(), packedResponse.size(), 0); + } + close(clientSocket); +} + +pta::StoreMessage LocalServer::ProcessRequest(int fd, const pta::StoreMessage &request) noexcept +{ + auto pos = requestHandlers_.find(request.mt); + if (pos != requestHandlers_.end()) { + auto response = pos->second(fd, request); + + // 在成功处理请求后通知 Proxy + if (proxyCallback_) { + proxyCallback_(request); + } + + return response; + } + + LOG(ERROR) << "unsupported message type " << static_cast(request.mt); + return request; +} + +// Implement other ProcessXXXRequest methods similar to ParallelStoreServer... 
+ +void LocalServer::InitializeHandlers() noexcept +{ + requestHandlers_.emplace(pta::MessageType::SET, + [this](int fd, const pta::StoreMessage &req) { return ProcessSetRequest(fd, req); }); + requestHandlers_.emplace(pta::MessageType::COMPARE_SET, + [this](int fd, const pta::StoreMessage &req) { return ProcessCompareSetRequest(fd, req); }); + requestHandlers_.emplace(pta::MessageType::GET, + [this](int fd, const pta::StoreMessage &req) { return ProcessGetRequest(fd, req); }); + requestHandlers_.emplace(pta::MessageType::ADD, + [this](int fd, const pta::StoreMessage &req) { return ProcessAddRequest(fd, req); }); + requestHandlers_.emplace(pta::MessageType::CHECK, + [this](int fd, const pta::StoreMessage &req) { return ProcessCheckRequest(fd, req); }); + requestHandlers_.emplace(pta::MessageType::WAIT, + [this](int fd, const pta::StoreMessage &req) { return ProcessWaitKeysRequest(fd, req); }); + requestHandlers_.emplace(pta::MessageType::GET_NUM_KEYS, + [this](int fd, const pta::StoreMessage &req) { return ProcessGetNumKeyRequest(fd, req); }); + requestHandlers_.emplace(pta::MessageType::DELETE_KEY, + [this](int fd, const pta::StoreMessage &req) { return ProcessDeleteRequest(fd, req); }); +} + +bool LocalServer::CheckAllKeysExistInLock(const std::vector &keys) noexcept +{ + return std::all_of(keys.begin(), keys.end(), [this](const std::string &key) { return keyStore_.count(key) > 0; }); +} + +} // namespace pta +} // namespace c10d \ No newline at end of file diff --git a/torch_npu/csrc/distributed/local_server.hpp b/torch_npu/csrc/distributed/local_server.hpp new file mode 100644 index 0000000000..f22ef7fdb8 --- /dev/null +++ b/torch_npu/csrc/distributed/local_server.hpp @@ -0,0 +1,62 @@ +// LocalServer.hpp +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include "StoreMessage.hpp" + +namespace c10d { +namespace pta { + +class LocalServer { +public: +using ProxyCallback = std::function; + LocalServer(std::string socketPath, 
c10::optional numWorkers); + virtual ~LocalServer() noexcept; + void Start(); + void Stop(); + void WaitWorkers(const std::chrono::milliseconds &timeout) noexcept; + void setProxyCallback(ProxyCallback callback) { + proxyCallback_ = std::move(callback); + } + +private: + + void RunServer(); + void HandleClient(int clientSocket); + pta::StoreMessage ProcessRequest(int fd, const pta::StoreMessage &request) noexcept; + pta::StoreMessage ProcessGetRequest(int fd, const pta::StoreMessage &request) noexcept; + pta::StoreMessage ProcessSetRequest(int fd, const pta::StoreMessage &request) noexcept; + pta::StoreMessage ProcessAddRequest(int fd, const pta::StoreMessage &request) noexcept; + pta::StoreMessage ProcessCheckRequest(int fd, const pta::StoreMessage &request) noexcept; + pta::StoreMessage ProcessDeleteRequest(int fd, const pta::StoreMessage &request) noexcept; + pta::StoreMessage ProcessCompareSetRequest(int fd, const pta::StoreMessage &request) noexcept; + pta::StoreMessage ProcessGetNumKeyRequest(int fd, const pta::StoreMessage &request) noexcept; + pta::StoreMessage ProcessWaitKeysRequest(int fd, const pta::StoreMessage &request) noexcept; + void InitializeHandlers() noexcept; + bool CheckAllKeysExistInLock(const std::vector &keys) noexcept; + + using RequestHandler = std::function; + std::unordered_map requestHandlers_; + std::unordered_map> keyStore_; + SpinLock serverLock_; + std::mutex initWaitMutex_; + std::condition_variable initWaitCond_; + std::atomic workersReady_{ false }; + const c10::optional numWorkers_; + const std::string socketPath_; + const std::string initKey_ = "init/"; + const std::string keyPrefix_ = "/"; + int serverSocket_; + std::atomic running_; + std::thread serverThread_; + ProxyCallback proxyCallback_; +}; + +} // namespace pta +} // namespace c10d \ No newline at end of file diff --git a/torch_npu/csrc/distributed/proxy.cpp b/torch_npu/csrc/distributed/proxy.cpp new file mode 100644 index 0000000000..f5de59a5bb --- /dev/null +++ 
b/torch_npu/csrc/distributed/proxy.cpp @@ -0,0 +1,47 @@ +#include "Proxy.hpp" +#include "c10/util/Exception.h" +#include + +namespace c10d { +namespace pta { + +Proxy::Proxy(const std::string& localSocketPath, const std::string& tcpHost, uint16_t tcpPort) + : localServer_(std::make_unique(localSocketPath, c10::nullopt)) +{ + localServer_->setProxyCallback([this](const StoreMessage& msg) { + this->HandleLocalServerMessage(msg); + }); + + if (!tcpHost.empty() && tcpPort != 0) { + tcpClient_ = std::make_unique(tcpHost, tcpPort); + } +} + +void Proxy::Start() +{ + localServer_->Start(); + if (tcpClient_) { + tcpClient_->Connect(); + } +} + +void Proxy::Stop() +{ + localServer_->Stop(); + if (tcpClient_) { + tcpClient_->Close(); + } +} + +void Proxy::HandleLocalServerMessage(const StoreMessage& message) +{ + if (tcpClient_) { + StoreMessage response; + tcpClient_->SyncCall(message, response); + // 如果需要,可以将响应发送回 LocalServer + // localServer_->HandleResponse(response); + } +} + +} // namespace pta +} // namespace c10d \ No newline at end of file diff --git a/torch_npu/csrc/distributed/proxy.hpp b/torch_npu/csrc/distributed/proxy.hpp new file mode 100644 index 0000000000..e9975d9fb3 --- /dev/null +++ b/torch_npu/csrc/distributed/proxy.hpp @@ -0,0 +1,25 @@ +#pragma once + +#include +#include +#include "LocalServer.hpp" +#include "TcpClient.hpp" + +namespace c10d { +namespace pta { + +class Proxy { +public: + Proxy(const std::string& localSocketPath, const std::string& tcpHost, uint16_t tcpPort); + void Start(); + void Stop(); + +private: + void HandleLocalServerMessage(const StoreMessage& message); + + std::unique_ptr localServer_; + std::unique_ptr tcpClient_; +}; + +} // namespace pta +} // namespace c10d \ No newline at end of file -- Gitee From 3b9c46f2f31fe60e2e289e20ede22ff8e6dbce0b Mon Sep 17 00:00:00 2001 From: wuyangyu Date: Thu, 15 Aug 2024 11:08:32 +0800 Subject: [PATCH 02/14] proxy1.11 --- .../distributed/{TcpClient.cpp => Client.cpp} | 75 +++++++- 
.../distributed/{TcpClient.hpp => Client.hpp} | 8 +- torch_npu/csrc/distributed/LocalClient.cpp | 136 -------------- torch_npu/csrc/distributed/LocalClient.hpp | 23 --- .../csrc/distributed/ParallelTcpServer.cpp | 89 ++++++++++ .../csrc/distributed/ParallelTcpServer.hpp | 6 +- .../csrc/distributed/ParallelTcpStore.cpp | 50 +++++- .../csrc/distributed/ParallelTcpStore.hpp | 11 +- torch_npu/csrc/distributed/local_server.cpp | 167 ------------------ torch_npu/csrc/distributed/local_server.hpp | 62 ------- torch_npu/csrc/distributed/proxy.cpp | 38 ++-- torch_npu/csrc/distributed/proxy.hpp | 9 +- 12 files changed, 242 insertions(+), 432 deletions(-) rename torch_npu/csrc/distributed/{TcpClient.cpp => Client.cpp} (67%) rename torch_npu/csrc/distributed/{TcpClient.hpp => Client.hpp} (82%) delete mode 100644 torch_npu/csrc/distributed/LocalClient.cpp delete mode 100644 torch_npu/csrc/distributed/LocalClient.hpp delete mode 100644 torch_npu/csrc/distributed/local_server.cpp delete mode 100644 torch_npu/csrc/distributed/local_server.hpp diff --git a/torch_npu/csrc/distributed/TcpClient.cpp b/torch_npu/csrc/distributed/Client.cpp similarity index 67% rename from torch_npu/csrc/distributed/TcpClient.cpp rename to torch_npu/csrc/distributed/Client.cpp index 44f3a95041..242b746a5a 100644 --- a/torch_npu/csrc/distributed/TcpClient.cpp +++ b/torch_npu/csrc/distributed/Client.cpp @@ -24,18 +24,23 @@ #include #include "c10/util/Logging.h" -#include "TcpClient.hpp" +#include "Client.hpp" namespace c10d { namespace pta { static constexpr uint32_t READ_BUF_SZ = 256; -TcpClient::TcpClient(std::string host, uint16_t port) noexcept - : host_{ std::move(host) }, port_{ port }, socketFd_{ -1 } -{} +Client::Client(const std::string& socketPath) + : path_(std::move(socketPath) ), socketFd_(-1) {} -int TcpClient::Connect() noexcept -{ +Client::Client(const std::string& host, uint16_t port) + : host_{ std::move(host) }, port_{ port }, socketFd_(-1) {} + +Client::~Client() { + Close(); +} + +int 
Client::Connect() { socketFd_ = socket(AF_INET, SOCK_STREAM, 0); if (socketFd_ < 0) { LOG(ERROR) << "create tcp client socket failed " << errno << " : " << strerror(errno); @@ -72,9 +77,61 @@ int TcpClient::Connect() noexcept } return -1; + if (CreateSocket() != 0) { + return -1; + } + return ConnectSocket(); +} + +int Client::LocalConnect() { + socketFd_ = socket(AF_INET, SOCK_STREAM, 0); + if (socketFd_ < 0) { + LOG(ERROR) << "Create local socket failed: " << strerror(errno); + return -1; + } + struct sockaddr_un servAddr {}; + servAddr.sun_family = AF_UNIX; + const auto& path = std::get(address_); + strncpy(servAddr.sun_path, path.c_str(), sizeof(servAddr.sun_path) - 1); + + int lastError = 0; + auto endTime = std::chrono::steady_clock::now() + std::chrono::minutes(1); + while (std::chrono::steady_clock::now() < endTime) { + auto ret = connect(socketFd_, reinterpret_cast(&servAddr), sizeof(servAddr)); + if (ret == 0) { + return 0; + } + + if (errno != lastError) { + LOG(ERROR) << "connect socket to local server(" << path << ") failed " << errno << " : " << strerror(errno); + lastError = errno; + } + + if (errno == ETIMEDOUT) { + continue; + } + + if (errno == ECONNREFUSED) { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + continue; + } + } +} + +int Client::Close() { + if (socketFd_ >= 0) { + int ret = close(socketFd_); + if (ret == 0) { + socketFd_ = -1; + } else { + LOG(ERROR) << "Close socket failed: " << strerror(errno); + } + return ret; + } + return 0; } -int TcpClient::Close() noexcept +int Client::Close() noexcept { auto ret = close(socketFd_); if (ret == 0) { @@ -87,7 +144,7 @@ int TcpClient::Close() noexcept return ret; } -int TcpClient::SyncCall(const StoreMessage &request, StoreMessage &response) noexcept +int Client::SyncCall(const StoreMessage &request, StoreMessage &response) noexcept { auto packedRequest = StoreMessagePacker::Pack(request); auto ret = write(socketFd_, packedRequest.data(), packedRequest.size()); @@ -134,7 +191,7 @@ 
int TcpClient::SyncCall(const StoreMessage &request, StoreMessage &response) noe return result; } -int TcpClient::SetReceiveTimeout(const std::chrono::milliseconds &value) const noexcept +int Client::SetReceiveTimeout(const std::chrono::milliseconds &value) const noexcept { if (value == std::chrono::milliseconds::zero()) { return 0; diff --git a/torch_npu/csrc/distributed/TcpClient.hpp b/torch_npu/csrc/distributed/Client.hpp similarity index 82% rename from torch_npu/csrc/distributed/TcpClient.hpp rename to torch_npu/csrc/distributed/Client.hpp index 822ff5e24e..0dbc6cbbcf 100644 --- a/torch_npu/csrc/distributed/TcpClient.hpp +++ b/torch_npu/csrc/distributed/Client.hpp @@ -23,15 +23,19 @@ namespace c10d { namespace pta { -class TcpClient { +class Client { public: - TcpClient(std::string host, uint16_t port) noexcept; + Client(const std::string& socketPath); // 用于 local client + Client(const std::string& host, uint16_t port); // 用于 tcp client + ~Client(); int Connect() noexcept; + int LocalConnect() noexcept; int Close() noexcept; int SyncCall(const StoreMessage &request, StoreMessage &response) noexcept; int SetReceiveTimeout(const std::chrono::milliseconds &value) const noexcept; private: + const std::string path_; const std::string host_; const uint16_t port_; int socketFd_; diff --git a/torch_npu/csrc/distributed/LocalClient.cpp b/torch_npu/csrc/distributed/LocalClient.cpp deleted file mode 100644 index 1560f0d12d..0000000000 --- a/torch_npu/csrc/distributed/LocalClient.cpp +++ /dev/null @@ -1,136 +0,0 @@ -#include -#include -#include - -#include -#include -#include -#include - -#include "c10/util/Logging.h" -#include "LocalClient.hpp" - -namespace c10d { -namespace pta { -static constexpr uint32_t READ_BUF_SZ = 256; - -LocalClient::LocalClient(std::string socketPath) noexcept - : socketPath_{ std::move(socketPath) }, socketFd_{ -1 } -{} - -int LocalClient::Connect() noexcept -{ - socketFd_ = socket(AF_UNIX, SOCK_STREAM, 0); - if (socketFd_ < 0) { - LOG(ERROR) 
<< "create local client socket failed " << errno << " : " << strerror(errno); - return -1; - } - - struct sockaddr_un servAddr {}; - servAddr.sun_family = AF_UNIX; - strncpy(servAddr.sun_path, socketPath_.c_str(), sizeof(servAddr.sun_path) - 1); - - int lastError = 0; - auto endTime = std::chrono::steady_clock::now() + std::chrono::minutes(1); - while (std::chrono::steady_clock::now() < endTime) { - auto ret = connect(socketFd_, reinterpret_cast(&servAddr), sizeof(servAddr)); - if (ret == 0) { - return 0; - } - - if (errno != lastError) { - LOG(ERROR) << "connect socket to local server(" << socketPath_ << ") failed " << errno << " : " << - strerror(errno); - lastError = errno; - } - - if (errno == ENOENT || errno == ECONNREFUSED) { - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - continue; - } - - break; - } - - return -1; -} - -int LocalClient::Close() noexcept -{ - auto ret = close(socketFd_); - if (ret == 0) { - socketFd_ = -1; - return 0; - } - - LOG(ERROR) << "close socket to local server(" << socketPath_ << ") failed " << errno << " : " << - strerror(errno); - return ret; -} - -int LocalClient::SyncCall(const StoreMessage &request, StoreMessage &response) noexcept -{ - auto packedRequest = StoreMessagePacker::Pack(request); - auto ret = write(socketFd_, packedRequest.data(), packedRequest.size()); - if (ret < 0) { - LOG(ERROR) << "write data to local server(" << socketPath_ << ") failed " << errno << " : " << - strerror(errno); - return -1; - } - - uint8_t buffer[READ_BUF_SZ]; - std::vector responseBuf; - - bool finished = false; - int result = -1; - while (!finished) { - do { - ret = read(socketFd_, buffer, READ_BUF_SZ); - if (ret < 0) { - LOG(ERROR) << "read data from local server(" << socketPath_ << ") failed " << errno << " : " << - strerror(errno); - return -1; - } - - responseBuf.insert(responseBuf.end(), buffer, buffer + ret); - } while (!StoreMessagePacker::Full(responseBuf)); - - auto unpackRet = 
StoreMessagePacker::Unpack(responseBuf, response); - if (unpackRet < 0L) { - LOG(ERROR) << "unpack response data from local server(" << socketPath_ << ") failed " << unpackRet; - finished = true; - result = -1; - continue; - } - - if (response.mt == request.mt) { - finished = true; - result = 0; - continue; - } - - responseBuf.erase(responseBuf.begin(), responseBuf.begin() + unpackRet); - } - - return result; -} - -int LocalClient::SetReceiveTimeout(const std::chrono::milliseconds &value) const noexcept -{ - if (value == std::chrono::milliseconds::zero()) { - return 0; - } - struct timeval timeoutTV = { - .tv_sec = static_cast(value.count() / 1000), - .tv_usec = static_cast((value.count() % 1000) * 1000) - }; - - auto ret = setsockopt(socketFd_, SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast(&timeoutTV), sizeof(timeoutTV)); - if (ret != 0) { - LOG(ERROR) << "set local connection receive timeout failed: " << errno << " : " << strerror(errno); - } - - return ret; -} -} // pta -} // c10d \ No newline at end of file diff --git a/torch_npu/csrc/distributed/LocalClient.hpp b/torch_npu/csrc/distributed/LocalClient.hpp deleted file mode 100644 index 4464ebcb37..0000000000 --- a/torch_npu/csrc/distributed/LocalClient.hpp +++ /dev/null @@ -1,23 +0,0 @@ -#pragma once - -#include -#include - -#include "StoreMessagePacker.hpp" - -namespace c10d { -namespace pta { -class LocalClient { -public: - explicit LocalClient(std::string socketPath) noexcept; - int Connect() noexcept; - int Close() noexcept; - int SyncCall(const StoreMessage &request, StoreMessage &response) noexcept; - int SetReceiveTimeout(const std::chrono::milliseconds &value) const noexcept; - -private: - const std::string socketPath_; - int socketFd_; -}; -} // pta -} // c10d \ No newline at end of file diff --git a/torch_npu/csrc/distributed/ParallelTcpServer.cpp b/torch_npu/csrc/distributed/ParallelTcpServer.cpp index 1d91ab6b02..7eb0fd5bcb 100644 --- a/torch_npu/csrc/distributed/ParallelTcpServer.cpp +++ 
b/torch_npu/csrc/distributed/ParallelTcpServer.cpp @@ -106,6 +106,10 @@ ParallelTcpServer::ParallelTcpServer(uint32_t threadNum, uint16_t port, ServerPr : threadNum_{ std::max(4U, threadNum) }, port_{ port }, process_{ std::move(process) } {} +ParallelTcpServer::ParallelTcpServer(uint32_t threadNum, const std::string& socketPath, CallBackFn callback) noexcept + : threadNum_{ std::max(4U, threadNum) }, socketPath_{ socketPath }, callback_{ std::move(callback) } +{} + int ParallelTcpServer::Start() noexcept { buffer_ = new (std::nothrow) uint8_t[4096]; @@ -155,6 +159,55 @@ int ParallelTcpServer::Start() noexcept return 0; } +int ParallelTcpServer::LocalStart() noexcept +{ + buffer_ = new (std::nothrow) uint8_t[4096]; + if (buffer_ == nullptr) { + LOG(ERROR) << "allocate buffer failed."; + return -1; + } + + listenSocket_ = CreateLocalSocket(socketPath_); + if (listenSocket_ < 0) { + delete[] buffer_; + buffer_ = nullptr; + return -1; + } + + epCtlFd_ = CreateEpoll(listenSocket_); + if (epCtlFd_ < 0) { + close(listenSocket_); + listenSocket_ = -1; + delete[] buffer_; + buffer_ = nullptr; + return -1; + } + + running_ = true; + epClientFds_.reserve(threadNum_); + clientThreads_.reserve(threadNum_); + auto initializeFailed = false; + for (auto i = 0U; i < threadNum_; i++) { + auto clientEpFd = CreateEpoll(); + if (clientEpFd < 0) { + LOG(ERROR) << "create new client epoll fd for index: " << i << " failed."; + initializeFailed = true; + break; + } + epClientFds_.emplace_back(clientEpFd); + clientThreads_.emplace_back([clientEpFd](ParallelTcpServer *server) { server->LoopProcessClients(clientEpFd); }, + this); + } + + ctlThread_ = std::thread{ [](ParallelTcpServer *server) { server->LoopProcessListenFd(); }, this }; + if (initializeFailed) { + Stop(); + return -1; + } + + return 0; +} + void ParallelTcpServer::Stop() noexcept { running_ = false; @@ -239,6 +292,42 @@ int ParallelTcpServer::CreateSocket(uint16_t port) noexcept return sockFd; } +int 
ParallelTcpServer::CreateLocalSocket(const std::string& socketPath) noexcept +{ + struct sockaddr_un servAddr {}; + servAddr.sun_family = AF_UNIX; + strncpy(servAddr.sun_path, socketPath.c_str(), sizeof(servAddr.sun_path) - 1); + + auto sockFd = ::socket(AF_UNIX, SOCK_STREAM, 0); + if (sockFd < 0) { + LOG(ERROR) << "create local socket fd failed " << errno << " : " << strerror(errno); + return -1; + } + + unlink(socketPath.c_str()); // Remove any existing socket file + + auto ret = ::bind(sockFd, reinterpret_cast(&servAddr), sizeof(servAddr)); + if (ret != 0) { + LOG(ERROR) << "bind local socket fd failed " << errno << " : " << strerror(errno); + close(sockFd); + return -1; + } + + ret = listen(sockFd, MAX_EVENT_COUNT); + if (ret != 0) { + LOG(ERROR) << "listen local socket fd failed " << errno << " : " << strerror(errno); + close(sockFd); + return -1; + } + + if (SetNonBlocking(sockFd) != 0) { + close(sockFd); + return -1; + } + + return sockFd; +} + int ParallelTcpServer::CreateEpoll(int targetFd) noexcept { auto fd = epoll_create(1); diff --git a/torch_npu/csrc/distributed/ParallelTcpServer.hpp b/torch_npu/csrc/distributed/ParallelTcpServer.hpp index ad95e94082..2505754357 100644 --- a/torch_npu/csrc/distributed/ParallelTcpServer.hpp +++ b/torch_npu/csrc/distributed/ParallelTcpServer.hpp @@ -99,9 +99,10 @@ using ServerProcFn = std::function &keys, int socket, int64_t waitCount) noexcept @@ -117,6 +118,7 @@ public: private: static int CreateSocket(uint16_t port) noexcept; + static int CreateLocalSocket(const std::string& socketPath) noexcept; static int CreateEpoll(int targetFd = -1) noexcept; @@ -133,7 +135,9 @@ private: private: const uint32_t threadNum_; const std::uint16_t port_; + const std::string socketPath_; const ServerProcFn process_; + const CallBackFn callback_; int listenSocket_{ -1 }; int epCtlFd_{ -1 }; std::thread ctlThread_; diff --git a/torch_npu/csrc/distributed/ParallelTcpStore.cpp b/torch_npu/csrc/distributed/ParallelTcpStore.cpp index 
1b3943e1e9..d17e9dfde1 100644 --- a/torch_npu/csrc/distributed/ParallelTcpStore.cpp +++ b/torch_npu/csrc/distributed/ParallelTcpStore.cpp @@ -34,7 +34,7 @@ ParallelStoreServer::ParallelStoreServer(std::string initKey, uint16_t port, c10 InitializeHandlers(); server_ = std::make_unique(threadNum, port, - [this](int fd, const pta::StoreMessage &request) { return ProcessRequest(fd, request); }); + [this](int fd, const pta::StoreMessage &request) { return return this->proxyCallback_(request); }); if (server_->Start() != 0) { throw std::runtime_error{ std::string("start tcp server on port ").append(std::to_string(port)).append(" failed.") @@ -42,6 +42,19 @@ ParallelStoreServer::ParallelStoreServer(std::string initKey, uint16_t port, c10 } } +ParallelStoreServer::ParallelStoreServer(std::string socketPath, ProxyCallback callback) + : socketPath_(std::move(socketPath)), proxyCallback_(std::move(callback)) +{ + auto threadNum = 4U; + server_ = std::make_unique(threadNum, socketPath_, + [this](int fd, const pta::StoreMessage &request) { return ProcessLocalRequest(fd, request); }); + if (server_->LocalStart() != 0) { + throw std::runtime_error{ + std::string("start local server on socket ").append(socketPath_).append(" failed.") + }; + } +} + ParallelStoreServer::~ParallelStoreServer() noexcept { server_->Stop(); @@ -255,7 +268,7 @@ std::unordered_map> ParallelTc ParallelTcpStore::ParallelTcpStore(const std::string& host, const TCPStoreOptions& opts) : Store(opts.timeout) { - if (opts.is_server) { + if (opts.isServer) { if (opts.multiTenant) { server_ = GetSharedServer(initKey_, opts.port, opts.numWorkers); } else { @@ -271,16 +284,27 @@ ParallelTcpStore::ParallelTcpStore(const std::string& host, const TCPStoreOption proxy_->Start(); } else { // 如果 LOCAL_RANK 环境变量存在,则为 Worker - localClient_ = std::make_unique("/tmp/torch_dist_store"); - localClient_->Connect(); + localClient_ = std::make_unique("/tmp/torch_dist_store"); + if (localClient_.LocalConnect() != 0) { + throw 
std::runtime_error{ std::string("connect local client to server failed.") }; + } + } + if (opts.waitWorkers) { + IncreaseKey(initKey_, 1); + if (opts.isServer) { + server_->WaitWorkers(timeout_); + } } - - // ... 其他初始化逻辑 ... } ParallelTcpStore::~ParallelTcpStore() noexcept { - client_.Close(); + if (proxy_) { + proxy_->Stop(); + } + if (localClient_) { + localClient_->Close(); + } } void ParallelTcpStore::set(const std::string &key, const std::vector &value) @@ -288,7 +312,7 @@ void ParallelTcpStore::set(const std::string &key, const std::vector &v pta::StoreMessage request{ pta::MessageType::SET, key, value }; pta::StoreMessage response; std::lock_guard lockGuard{ clientMutex_ }; - auto ret = client_.SyncCall(request, response); + auto ret = localClient_.SyncCall(request, response); if (ret != 0) { throw std::runtime_error{ std::string("set key ").append(key).append(" failed.") }; } @@ -394,7 +418,15 @@ int64_t ParallelTcpStore::IncreaseKey(const std::string &key, int64_t value) pta::StoreMessage request{ pta::MessageType::ADD, key, pta::StoreMessagePacker::PackPod(value) }; pta::StoreMessage response; std::lock_guard lockGuard{ clientMutex_ }; - auto ret = client_.SyncCall(request, response); + int ret = -1; + if (proxy_) { + ret = proxy_->SyncCall(request, response); + } else if (localClient_) { + ret = localClient_->SyncCall(request, response); + } else { + throw std::runtime_error{ "No valid client available for operation." 
}; + } + if (ret != 0) { throw std::runtime_error{ std::string("add key ").append(key).append(" failed.") }; } diff --git a/torch_npu/csrc/distributed/ParallelTcpStore.hpp b/torch_npu/csrc/distributed/ParallelTcpStore.hpp index 0841136eb5..b7511afd26 100644 --- a/torch_npu/csrc/distributed/ParallelTcpStore.hpp +++ b/torch_npu/csrc/distributed/ParallelTcpStore.hpp @@ -25,15 +25,19 @@ #include #include "c10d/TCPStore.hpp" -#include "TcpClient.hpp" +#include "Client.hpp" +#include "Proxy.hpp" #include "ParallelTcpServer.hpp" namespace c10d { namespace pta { class ParallelStoreServer { public: + using MessageCallback = std::function; ParallelStoreServer(std::string initKey, uint16_t port, c10::optional numWorkers); + ParallelStoreServer(std::string socketPath, ProxyCallback callback); virtual ~ParallelStoreServer() noexcept; void WaitWorkers(const std::chrono::milliseconds &timeout) noexcept; + void setMessageCallback(MessageCallback callback); private: pta::StoreMessage ProcessRequest(int fd, const pta::StoreMessage &request) noexcept; @@ -49,6 +53,8 @@ private: bool CheckAllKeysExistInLock(const std::vector &keys) noexcept; private: + ProxyCallback proxyCallback_; + std::string socketPath_; using RequestHandler = std::function; std::unique_ptr server_; std::unordered_map requestHandlers_; @@ -89,7 +95,8 @@ private: c10::optional numWorkers); private: - pta::TcpClient client_; + std::unique_ptr proxy_; + std::unique_ptr localClient_; std::shared_ptr server_; std::mutex clientMutex_; std::condition_variable initWaitCond_; diff --git a/torch_npu/csrc/distributed/local_server.cpp b/torch_npu/csrc/distributed/local_server.cpp deleted file mode 100644 index e267e3ade0..0000000000 --- a/torch_npu/csrc/distributed/local_server.cpp +++ /dev/null @@ -1,167 +0,0 @@ -// LocalServer.cpp -#include "LocalServer.hpp" -#include -#include -#include -#include -#include -#include "c10/util/Logging.h" - -namespace c10d { -namespace pta { - -LocalServer::LocalServer(std::string 
socketPath, c10::optional numWorkers) - : numWorkers_(numWorkers), socketPath_(std::move(socketPath)), serverSocket_(-1), running_(false) -{ - InitializeHandlers(); -} - -LocalServer::~LocalServer() noexcept -{ - Stop(); -} - -void LocalServer::Start() -{ - serverSocket_ = socket(AF_UNIX, SOCK_STREAM, 0); - if (serverSocket_ < 0) { - throw std::runtime_error("Failed to create Unix domain socket"); - } - - struct sockaddr_un addr; - memset(&addr, 0, sizeof(addr)); - addr.sun_family = AF_UNIX; - strncpy(addr.sun_path, socketPath_.c_str(), sizeof(addr.sun_path) - 1); - - unlink(socketPath_.c_str()); // Remove any existing socket file - - if (bind(serverSocket_, (struct sockaddr*)&addr, sizeof(addr)) < 0) { - close(serverSocket_); - throw std::runtime_error("Failed to bind Unix domain socket"); - } - - if (listen(serverSocket_, 5) < 0) { - close(serverSocket_); - throw std::runtime_error("Failed to listen on Unix domain socket"); - } - - running_ = true; - serverThread_ = std::thread(&LocalServer::RunServer, this); -} - -void LocalServer::Stop() -{ - running_ = false; - if (serverThread_.joinable()) { - serverThread_.join(); - } - if (serverSocket_ >= 0) { - close(serverSocket_); - serverSocket_ = -1; - } - unlink(socketPath_.c_str()); -} - -void LocalServer::WaitWorkers(const std::chrono::milliseconds &timeout) noexcept -{ - if (numWorkers_ == c10::nullopt) { - return; - } - - const auto start = std::chrono::steady_clock::now(); - while (!workersReady_) { - std::unique_lock lockGuard{ initWaitMutex_ }; - if (timeout == Store::kNoTimeout) { - initWaitCond_.wait(lockGuard); - } else { - initWaitCond_.wait_until(lockGuard, start + timeout); - } - } -} - -void LocalServer::RunServer() -{ - while (running_) { - int clientSocket = accept(serverSocket_, nullptr, nullptr); - if (clientSocket < 0) { - if (errno != EINTR) { - LOG(ERROR) << "Failed to accept client connection: " << strerror(errno); - } - continue; - } - - std::thread clientThread(&LocalServer::HandleClient, 
this, clientSocket); - clientThread.detach(); - } -} - -void LocalServer::HandleClient(int clientSocket) -{ - std::vector buffer(1024); - while (running_) { - ssize_t bytesRead = recv(clientSocket, buffer.data(), buffer.size(), 0); - if (bytesRead <= 0) { - break; - } - - StoreMessage request; - auto unpackResult = StoreMessagePacker::Unpack(buffer, request); - if (unpackResult <= 0) { - LOG(ERROR) << "Failed to unpack client request"; - continue; - } - - StoreMessage response = ProcessRequest(clientSocket, request); - auto packedResponse = StoreMessagePacker::Pack(response); - send(clientSocket, packedResponse.data(), packedResponse.size(), 0); - } - close(clientSocket); -} - -pta::StoreMessage LocalServer::ProcessRequest(int fd, const pta::StoreMessage &request) noexcept -{ - auto pos = requestHandlers_.find(request.mt); - if (pos != requestHandlers_.end()) { - auto response = pos->second(fd, request); - - // 在成功处理请求后通知 Proxy - if (proxyCallback_) { - proxyCallback_(request); - } - - return response; - } - - LOG(ERROR) << "unsupported message type " << static_cast(request.mt); - return request; -} - -// Implement other ProcessXXXRequest methods similar to ParallelStoreServer... 
- -void LocalServer::InitializeHandlers() noexcept -{ - requestHandlers_.emplace(pta::MessageType::SET, - [this](int fd, const pta::StoreMessage &req) { return ProcessSetRequest(fd, req); }); - requestHandlers_.emplace(pta::MessageType::COMPARE_SET, - [this](int fd, const pta::StoreMessage &req) { return ProcessCompareSetRequest(fd, req); }); - requestHandlers_.emplace(pta::MessageType::GET, - [this](int fd, const pta::StoreMessage &req) { return ProcessGetRequest(fd, req); }); - requestHandlers_.emplace(pta::MessageType::ADD, - [this](int fd, const pta::StoreMessage &req) { return ProcessAddRequest(fd, req); }); - requestHandlers_.emplace(pta::MessageType::CHECK, - [this](int fd, const pta::StoreMessage &req) { return ProcessCheckRequest(fd, req); }); - requestHandlers_.emplace(pta::MessageType::WAIT, - [this](int fd, const pta::StoreMessage &req) { return ProcessWaitKeysRequest(fd, req); }); - requestHandlers_.emplace(pta::MessageType::GET_NUM_KEYS, - [this](int fd, const pta::StoreMessage &req) { return ProcessGetNumKeyRequest(fd, req); }); - requestHandlers_.emplace(pta::MessageType::DELETE_KEY, - [this](int fd, const pta::StoreMessage &req) { return ProcessDeleteRequest(fd, req); }); -} - -bool LocalServer::CheckAllKeysExistInLock(const std::vector &keys) noexcept -{ - return std::all_of(keys.begin(), keys.end(), [this](const std::string &key) { return keyStore_.count(key) > 0; }); -} - -} // namespace pta -} // namespace c10d \ No newline at end of file diff --git a/torch_npu/csrc/distributed/local_server.hpp b/torch_npu/csrc/distributed/local_server.hpp deleted file mode 100644 index f22ef7fdb8..0000000000 --- a/torch_npu/csrc/distributed/local_server.hpp +++ /dev/null @@ -1,62 +0,0 @@ -// LocalServer.hpp -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include "StoreMessage.hpp" - -namespace c10d { -namespace pta { - -class LocalServer { -public: -using ProxyCallback = std::function; - LocalServer(std::string 
socketPath, c10::optional numWorkers); - virtual ~LocalServer() noexcept; - void Start(); - void Stop(); - void WaitWorkers(const std::chrono::milliseconds &timeout) noexcept; - void setProxyCallback(ProxyCallback callback) { - proxyCallback_ = std::move(callback); - } - -private: - - void RunServer(); - void HandleClient(int clientSocket); - pta::StoreMessage ProcessRequest(int fd, const pta::StoreMessage &request) noexcept; - pta::StoreMessage ProcessGetRequest(int fd, const pta::StoreMessage &request) noexcept; - pta::StoreMessage ProcessSetRequest(int fd, const pta::StoreMessage &request) noexcept; - pta::StoreMessage ProcessAddRequest(int fd, const pta::StoreMessage &request) noexcept; - pta::StoreMessage ProcessCheckRequest(int fd, const pta::StoreMessage &request) noexcept; - pta::StoreMessage ProcessDeleteRequest(int fd, const pta::StoreMessage &request) noexcept; - pta::StoreMessage ProcessCompareSetRequest(int fd, const pta::StoreMessage &request) noexcept; - pta::StoreMessage ProcessGetNumKeyRequest(int fd, const pta::StoreMessage &request) noexcept; - pta::StoreMessage ProcessWaitKeysRequest(int fd, const pta::StoreMessage &request) noexcept; - void InitializeHandlers() noexcept; - bool CheckAllKeysExistInLock(const std::vector &keys) noexcept; - - using RequestHandler = std::function; - std::unordered_map requestHandlers_; - std::unordered_map> keyStore_; - SpinLock serverLock_; - std::mutex initWaitMutex_; - std::condition_variable initWaitCond_; - std::atomic workersReady_{ false }; - const c10::optional numWorkers_; - const std::string socketPath_; - const std::string initKey_ = "init/"; - const std::string keyPrefix_ = "/"; - int serverSocket_; - std::atomic running_; - std::thread serverThread_; - ProxyCallback proxyCallback_; -}; - -} // namespace pta -} // namespace c10d \ No newline at end of file diff --git a/torch_npu/csrc/distributed/proxy.cpp b/torch_npu/csrc/distributed/proxy.cpp index f5de59a5bb..ab278f0683 100644 --- 
a/torch_npu/csrc/distributed/proxy.cpp +++ b/torch_npu/csrc/distributed/proxy.cpp @@ -6,41 +6,45 @@ namespace c10d { namespace pta { Proxy::Proxy(const std::string& localSocketPath, const std::string& tcpHost, uint16_t tcpPort) - : localServer_(std::make_unique(localSocketPath, c10::nullopt)) + : localServer_(std::make_unique(localSocketPath)), + tcpClient_(std::make_unique(tcpHost, tcpPort)) { localServer_->setProxyCallback([this](const StoreMessage& msg) { - this->HandleLocalServerMessage(msg); + return this->HandleLocalServerMessage(msg); }); - if (!tcpHost.empty() && tcpPort != 0) { - tcpClient_ = std::make_unique(tcpHost, tcpPort); - } +} + +Proxy::~Proxy() +{ + Stop(); } void Proxy::Start() { - localServer_->Start(); - if (tcpClient_) { - tcpClient_->Connect(); + localServer_->LocalStart(); + if (tcpClient_->Connect() != 0) { + throw std::runtime_error("Failed to connect to TCP server"); } } void Proxy::Stop() { localServer_->Stop(); - if (tcpClient_) { - tcpClient_->Close(); - } + tcpClient_->Close(); +} + +int Proxy::SyncCall(const pta::StoreMessage& request, pta::StoreMessage& response) { + return tcpClient_->SyncCall(request, response); } -void Proxy::HandleLocalServerMessage(const StoreMessage& message) +StoreMessage Proxy::HandleLocalServerMessage(const StoreMessage& message) { - if (tcpClient_) { - StoreMessage response; - tcpClient_->SyncCall(message, response); - // 如果需要,可以将响应发送回 LocalServer - // localServer_->HandleResponse(response); + StoreMessage response; + if (tcpClient_->SyncCall(message, response) != 0) { + throw std::runtime_error("Failed to sync call with TCP server"); } + return response; } } // namespace pta diff --git a/torch_npu/csrc/distributed/proxy.hpp b/torch_npu/csrc/distributed/proxy.hpp index e9975d9fb3..96cfe23f59 100644 --- a/torch_npu/csrc/distributed/proxy.hpp +++ b/torch_npu/csrc/distributed/proxy.hpp @@ -2,8 +2,8 @@ #include #include -#include "LocalServer.hpp" -#include "TcpClient.hpp" +#include "ParallelTcpStore.hpp" 
+#include "Client.hpp" namespace c10d { namespace pta { @@ -13,12 +13,13 @@ public: Proxy(const std::string& localSocketPath, const std::string& tcpHost, uint16_t tcpPort); void Start(); void Stop(); + void SyncCall(); private: void HandleLocalServerMessage(const StoreMessage& message); - std::unique_ptr localServer_; - std::unique_ptr tcpClient_; + std::unique_ptr localServer_; + std::unique_ptr tcpClient_; }; } // namespace pta -- Gitee From 3bf70029eb57e565be590a0d346ce01fd7b044ba Mon Sep 17 00:00:00 2001 From: wuyangyu Date: Wed, 21 Aug 2024 10:27:16 +0800 Subject: [PATCH 03/14] change8.21 --- torch_npu/csrc/distributed/Client.cpp | 22 +++---- torch_npu/csrc/distributed/Client.hpp | 1 - .../csrc/distributed/ParallelTcpServer.cpp | 57 +++---------------- .../csrc/distributed/ParallelTcpServer.hpp | 1 + .../csrc/distributed/ParallelTcpStore.cpp | 18 +++--- .../csrc/distributed/ParallelTcpStore.hpp | 6 +- torch_npu/csrc/distributed/proxy.cpp | 1 + torch_npu/csrc/distributed/proxy.hpp | 2 +- 8 files changed, 31 insertions(+), 77 deletions(-) diff --git a/torch_npu/csrc/distributed/Client.cpp b/torch_npu/csrc/distributed/Client.cpp index c901a3c2d2..e6e0bbbff7 100644 --- a/torch_npu/csrc/distributed/Client.cpp +++ b/torch_npu/csrc/distributed/Client.cpp @@ -37,10 +37,6 @@ Client::Client(const std::string& socketPath) Client::Client(const std::string& host, uint16_t port) : host_{ std::move(host) }, port_{ port }, socketFd_(-1) {} -Client::~Client() { - Close(); -} - int Client::Connect() { socketFd_ = socket(AF_INET, SOCK_STREAM, 0); if (socketFd_ < 0) { @@ -117,17 +113,17 @@ int Client::LocalConnect() { } -int Client::Close() { +int Client::LocalClose() { if (socketFd_ >= 0) { - int ret = close(socketFd_); - if (ret == 0) { - socketFd_ = -1; - } else { - LOG(ERROR) << "Close socket failed: " << strerror(errno); - } - return ret; + auto ret = close(socketFd_); + if (ret == 0) { + socketFd_ = -1; + return 0; } - return 0; + + LOG(ERROR) << "close socket to server(" 
<< path_ << ") failed " << errno << " : " << + strerror(errno); + return ret; } int Client::Close() noexcept diff --git a/torch_npu/csrc/distributed/Client.hpp b/torch_npu/csrc/distributed/Client.hpp index 0dbc6cbbcf..a0564abab1 100644 --- a/torch_npu/csrc/distributed/Client.hpp +++ b/torch_npu/csrc/distributed/Client.hpp @@ -27,7 +27,6 @@ class Client { public: Client(const std::string& socketPath); // 用于 local client Client(const std::string& host, uint16_t port); // 用于 tcp client - ~Client(); int Connect() noexcept; int LocalConnect() noexcept; int Close() noexcept; diff --git a/torch_npu/csrc/distributed/ParallelTcpServer.cpp b/torch_npu/csrc/distributed/ParallelTcpServer.cpp index 55337b7403..3069c56942 100644 --- a/torch_npu/csrc/distributed/ParallelTcpServer.cpp +++ b/torch_npu/csrc/distributed/ParallelTcpServer.cpp @@ -108,58 +108,11 @@ ParallelTcpServer::ParallelTcpServer(uint32_t threadNum, uint16_t port, ServerPr ParallelTcpServer::ParallelTcpServer(uint32_t threadNum, const std::string& socketPath, CallBackFn callback) noexcept : threadNum_{ std::max(4U, threadNum) }, socketPath_{ socketPath }, callback_{ std::move(callback) } -{} - -int ParallelTcpServer::Start() noexcept { - buffer_ = new (std::nothrow) uint8_t[4096]; - if (buffer_ == nullptr) { - LOG(ERROR) << "allocate buffer failed."; - return -1; - } - - listenSocket_ = CreateSocket(port_); - if (listenSocket_ < 0) { - delete[] buffer_; - buffer_ = nullptr; - return -1; - } - - epCtlFd_ = CreateEpoll(listenSocket_); - if (epCtlFd_ < 0) { - close(listenSocket_); - listenSocket_ = -1; - delete[] buffer_; - buffer_ = nullptr; - return -1; - } - - running_ = true; - epClientFds_.reserve(threadNum_); - clientThreads_.reserve(threadNum_); - auto initializeFailed = false; - for (auto i = 0U; i < threadNum_; i++) { - auto clientEpFd = CreateEpoll(); - if (clientEpFd < 0) { - LOG(ERROR) << "create new client epoll fd for index: " << i << " failed."; - initializeFailed = true; - break; - } - 
epClientFds_.emplace_back(clientEpFd); - clientThreads_.emplace_back([clientEpFd](ParallelTcpServer *server) { server->LoopProcessClients(clientEpFd); }, - this); - } - - ctlThread_ = std::thread{ [](ParallelTcpServer *server) { server->LoopProcessListenFd(); }, this }; - if (initializeFailed) { - Stop(); - return -1; - } - - return 0; + isLocalServer_ = true; } -int ParallelTcpServer::LocalStart() noexcept +int ParallelTcpServer::Start() noexcept { buffer_ = new (std::nothrow) uint8_t[4096]; if (buffer_ == nullptr) { @@ -167,7 +120,11 @@ int ParallelTcpServer::LocalStart() noexcept return -1; } - listenSocket_ = CreateLocalSocket(socketPath_); + if(isLocalServer_){ + listenSocket_ = CreateLocalSocket(socketPath_); + }else{ + listenSocket_ = CreateSocket(port_); + } if (listenSocket_ < 0) { delete[] buffer_; buffer_ = nullptr; diff --git a/torch_npu/csrc/distributed/ParallelTcpServer.hpp b/torch_npu/csrc/distributed/ParallelTcpServer.hpp index 9448f37e84..55efcbfdf2 100644 --- a/torch_npu/csrc/distributed/ParallelTcpServer.hpp +++ b/torch_npu/csrc/distributed/ParallelTcpServer.hpp @@ -141,6 +141,7 @@ private: const CallBackFn callback_; int listenSocket_{ -1 }; int epCtlFd_{ -1 }; + bool isLocalServer_{ false }; std::thread ctlThread_; std::vector epClientFds_; std::vector clientThreads_; diff --git a/torch_npu/csrc/distributed/ParallelTcpStore.cpp b/torch_npu/csrc/distributed/ParallelTcpStore.cpp index 665a26665b..41d9c6fd84 100644 --- a/torch_npu/csrc/distributed/ParallelTcpStore.cpp +++ b/torch_npu/csrc/distributed/ParallelTcpStore.cpp @@ -34,7 +34,7 @@ ParallelStoreServer::ParallelStoreServer(std::string initKey, uint16_t port, c10 InitializeHandlers(); server_ = std::make_unique(threadNum, port, - [this](int fd, const pta::StoreMessage &request) { return ProcessLocalRequest(fd, request); }); + [this](int fd, const pta::StoreMessage &request) { return ProcessRequest(fd, request); }); if (server_->Start() != 0) { throw std::runtime_error{ std::string("start tcp 
server on port ").append(std::to_string(port)).append(" failed.") @@ -48,7 +48,7 @@ ParallelStoreServer::ParallelStoreServer(std::string socketPath, CallBackFn call auto threadNum = 4U; server_ = std::make_unique(threadNum, socketPath_, [this](const pta::StoreMessage &request) { return this->callback_(request); }); - if (server_->LocalStart() != 0) { + if (server_->Start() != 0) { throw std::runtime_error{ std::string("start local server on socket ").append(socketPath_).append(" failed.") }; @@ -284,8 +284,8 @@ ParallelTcpStore::ParallelTcpStore(const std::string& host, const TCPStoreOption proxy_->Start(); } else { // 如果 LOCAL_RANK 环境变量存在,则为 Worker - localClient_ = std::make_unique("/tmp/torch_dist_store"); - if (localClient_.LocalConnect() != 0) { + Client_ = std::make_unique("/tmp/torch_dist_store"); + if (Client_.LocalConnect() != 0) { throw std::runtime_error{ std::string("connect local client to server failed.") }; } } @@ -302,8 +302,8 @@ ParallelTcpStore::~ParallelTcpStore() noexcept if (proxy_) { proxy_->Stop(); } - if (localClient_) { - localClient_->Close(); + if (Client_) { + Client_->LocalClose(); } } @@ -312,7 +312,7 @@ void ParallelTcpStore::set(const std::string &key, const std::vector &v pta::StoreMessage request{ pta::MessageType::SET, key, value }; pta::StoreMessage response; std::lock_guard lockGuard{ clientMutex_ }; - auto ret = localClient_.SyncCall(request, response); + auto ret = Client_.SyncCall(request, response); if (ret != 0) { throw std::runtime_error{ std::string("set key ").append(key).append(" failed.") }; } @@ -421,8 +421,8 @@ int64_t ParallelTcpStore::IncreaseKey(const std::string &key, int64_t value) int ret = -1; if (proxy_) { ret = proxy_->SyncCall(request, response); - } else if (localClient_) { - ret = localClient_->SyncCall(request, response); + } else if (Client_) { + ret = Client_->SyncCall(request, response); } else { throw std::runtime_error{ "No valid client available for operation." 
}; } diff --git a/torch_npu/csrc/distributed/ParallelTcpStore.hpp b/torch_npu/csrc/distributed/ParallelTcpStore.hpp index 1674f1bd32..ca0839baea 100644 --- a/torch_npu/csrc/distributed/ParallelTcpStore.hpp +++ b/torch_npu/csrc/distributed/ParallelTcpStore.hpp @@ -30,9 +30,9 @@ #include "ParallelTcpServer.hpp" namespace c10d { namespace pta { +using CallBackFn = std::function; class ParallelStoreServer { -public: - using CallBackFn = std::function; +public: ParallelStoreServer(std::string initKey, uint16_t port, c10::optional numWorkers); ParallelStoreServer(std::string socketPath, CallBackFn callback); virtual ~ParallelStoreServer() noexcept; @@ -95,7 +95,7 @@ private: private: std::unique_ptr proxy_; - std::unique_ptr localClient_; + std::unique_ptr Client_; std::shared_ptr server_; std::mutex clientMutex_; std::condition_variable initWaitCond_; diff --git a/torch_npu/csrc/distributed/proxy.cpp b/torch_npu/csrc/distributed/proxy.cpp index 92c4c5adef..fc4ba332e8 100644 --- a/torch_npu/csrc/distributed/proxy.cpp +++ b/torch_npu/csrc/distributed/proxy.cpp @@ -1,4 +1,5 @@ #include "Proxy.hpp" +#include "ParallelStoreServer.hpp" #include "c10/util/Exception.h" #include diff --git a/torch_npu/csrc/distributed/proxy.hpp b/torch_npu/csrc/distributed/proxy.hpp index 4502f3b730..fa0f4189a3 100644 --- a/torch_npu/csrc/distributed/proxy.hpp +++ b/torch_npu/csrc/distributed/proxy.hpp @@ -13,7 +13,7 @@ public: Proxy(const std::string& localSocketPath, const std::string& tcpHost, uint16_t tcpPort); void Start(); void Stop(); - void SyncCall(); + int SyncCall(); private: void HandleLocalServerMessage(const StoreMessage& message); -- Gitee From 65b2890d1a45e53ead0d40cd664bf3debeaf536e Mon Sep 17 00:00:00 2001 From: wuyangyu Date: Wed, 21 Aug 2024 10:30:10 +0800 Subject: [PATCH 04/14] change8.21 --- torch_npu/csrc/distributed/ParallelTcpServer.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_npu/csrc/distributed/ParallelTcpServer.cpp 
b/torch_npu/csrc/distributed/ParallelTcpServer.cpp index 3069c56942..fd1ae2bb1e 100644 --- a/torch_npu/csrc/distributed/ParallelTcpServer.cpp +++ b/torch_npu/csrc/distributed/ParallelTcpServer.cpp @@ -422,7 +422,7 @@ void ParallelTcpServer::ProcessClientEvent(int epFd, int fd, uint32_t event, if (event & EPOLLIN) { pos->second.ReceiveData(); while (pos->second.HasNextReq()) { - auto response = socketPath_.empty() ? process_(fd, pos->second.NextRequest()) : callback_(pos->second.NextRequest()); + auto response = isLocalServer_ ? callback_(pos->second.NextRequest()) : process_(fd, pos->second.NextRequest()); pos->second.SendResponse(response); } -- Gitee From 31a3a81ec3eab7d82666010cf9b64d895538a0e4 Mon Sep 17 00:00:00 2001 From: wuyangyu Date: Wed, 21 Aug 2024 10:52:17 +0800 Subject: [PATCH 05/14] change8.21 --- torch_npu/csrc/distributed/Client.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/torch_npu/csrc/distributed/Client.cpp b/torch_npu/csrc/distributed/Client.cpp index e6e0bbbff7..7c2e940abb 100644 --- a/torch_npu/csrc/distributed/Client.cpp +++ b/torch_npu/csrc/distributed/Client.cpp @@ -114,7 +114,6 @@ int Client::LocalConnect() { } int Client::LocalClose() { - if (socketFd_ >= 0) { auto ret = close(socketFd_); if (ret == 0) { socketFd_ = -1; -- Gitee From 7984e1467318d194346af5e6dc730b89afeab079 Mon Sep 17 00:00:00 2001 From: wuyangyu Date: Wed, 21 Aug 2024 14:24:14 +0800 Subject: [PATCH 06/14] change8.21 --- torch_npu/csrc/distributed/Client.cpp | 8 ++--- torch_npu/csrc/distributed/Client.hpp | 7 +++-- .../csrc/distributed/ParallelTcpServer.cpp | 12 ++++---- .../csrc/distributed/ParallelTcpServer.hpp | 6 ++-- .../csrc/distributed/ParallelTcpStore.cpp | 30 +++++++++---------- .../csrc/distributed/ParallelTcpStore.hpp | 6 ++-- torch_npu/csrc/distributed/proxy.cpp | 9 ++---- torch_npu/csrc/distributed/proxy.hpp | 2 +- 8 files changed, 37 insertions(+), 43 deletions(-) diff --git a/torch_npu/csrc/distributed/Client.cpp 
b/torch_npu/csrc/distributed/Client.cpp index 7c2e940abb..944e10595f 100644 --- a/torch_npu/csrc/distributed/Client.cpp +++ b/torch_npu/csrc/distributed/Client.cpp @@ -31,8 +31,8 @@ namespace c10d { namespace pta { static constexpr uint32_t READ_BUF_SZ = 256; -Client::Client(const std::string& socketPath) - : path_(std::move(socketPath) ), socketFd_(-1) {} +Client::Client(const std::string& localSocketPath) + : localSocketPath_(std::move(localSocketPath) ), socketFd_(-1) {} Client::Client(const std::string& host, uint16_t port) : host_{ std::move(host) }, port_{ port }, socketFd_(-1) {} @@ -84,7 +84,7 @@ int Client::LocalConnect() { } struct sockaddr_un servAddr {}; servAddr.sun_family = AF_UNIX; - strncpy(servAddr.sun_path, path_.c_str(), sizeof(servAddr.sun_path) - 1); + strncpy(servAddr.sun_path, localSocketPath_.c_str(), sizeof(servAddr.sun_path) - 1); int lastError = 0; auto endTime = std::chrono::steady_clock::now() + std::chrono::minutes(1); @@ -95,7 +95,7 @@ int Client::LocalConnect() { } if (errno != lastError) { - LOG(ERROR) << "connect socket to local server(" << path << ") failed " << errno << " : " << strerror(errno); + LOG(ERROR) << "connect socket to local server(" << localSocketPath_ << ") failed " << errno << " : " << strerror(errno); lastError = errno; } diff --git a/torch_npu/csrc/distributed/Client.hpp b/torch_npu/csrc/distributed/Client.hpp index a0564abab1..707e99fab7 100644 --- a/torch_npu/csrc/distributed/Client.hpp +++ b/torch_npu/csrc/distributed/Client.hpp @@ -25,16 +25,17 @@ namespace c10d { namespace pta { class Client { public: - Client(const std::string& socketPath); // 用于 local client - Client(const std::string& host, uint16_t port); // 用于 tcp client + Client(const std::string& localSocketPath); // for local client + Client(const std::string& host, uint16_t port); // for tcp client int Connect() noexcept; int LocalConnect() noexcept; int Close() noexcept; + int LocalClose() noexcept; int SyncCall(const StoreMessage &request, 
StoreMessage &response) noexcept; int SetReceiveTimeout(const std::chrono::milliseconds &value) const noexcept; private: - const std::string path_; + const std::string localSocketPath_; const std::string host_; const uint16_t port_; int socketFd_; diff --git a/torch_npu/csrc/distributed/ParallelTcpServer.cpp b/torch_npu/csrc/distributed/ParallelTcpServer.cpp index fd1ae2bb1e..1014f3c309 100644 --- a/torch_npu/csrc/distributed/ParallelTcpServer.cpp +++ b/torch_npu/csrc/distributed/ParallelTcpServer.cpp @@ -106,8 +106,8 @@ ParallelTcpServer::ParallelTcpServer(uint32_t threadNum, uint16_t port, ServerPr : threadNum_{ std::max(4U, threadNum) }, port_{ port }, process_{ std::move(process) } {} -ParallelTcpServer::ParallelTcpServer(uint32_t threadNum, const std::string& socketPath, CallBackFn callback) noexcept - : threadNum_{ std::max(4U, threadNum) }, socketPath_{ socketPath }, callback_{ std::move(callback) } +ParallelTcpServer::ParallelTcpServer(uint32_t threadNum, const std::string& localSocketPath, CallBackFn callback) noexcept + : threadNum_{ std::max(4U, threadNum) }, localSocketPath_{ localSocketPath }, callback_{ std::move(callback) } { isLocalServer_ = true; } @@ -121,7 +121,7 @@ int ParallelTcpServer::Start() noexcept } if(isLocalServer_){ - listenSocket_ = CreateLocalSocket(socketPath_); + listenSocket_ = CreateLocalSocket(localSocketPath_); }else{ listenSocket_ = CreateSocket(port_); } @@ -249,11 +249,11 @@ int ParallelTcpServer::CreateSocket(uint16_t port) noexcept return sockFd; } -int ParallelTcpServer::CreateLocalSocket(const std::string& socketPath) noexcept +int ParallelTcpServer::CreateLocalSocket(const std::string& localSocketPath) noexcept { struct sockaddr_un servAddr {}; servAddr.sun_family = AF_UNIX; - strncpy(servAddr.sun_path, socketPath.c_str(), sizeof(servAddr.sun_path) - 1); + strncpy(servAddr.sun_path, localSocketPath.c_str(), sizeof(servAddr.sun_path) - 1); auto sockFd = ::socket(AF_UNIX, SOCK_STREAM, 0); if (sockFd < 0) { @@ -261,7 
+261,7 @@ int ParallelTcpServer::CreateLocalSocket(const std::string& socketPath) noexcept return -1; } - unlink(socketPath.c_str()); // Remove any existing socket file + unlink(localSocketPath.c_str()); // Remove any existing socket file auto ret = ::bind(sockFd, reinterpret_cast(&servAddr), sizeof(servAddr)); if (ret != 0) { diff --git a/torch_npu/csrc/distributed/ParallelTcpServer.hpp b/torch_npu/csrc/distributed/ParallelTcpServer.hpp index 55efcbfdf2..2723e828c3 100644 --- a/torch_npu/csrc/distributed/ParallelTcpServer.hpp +++ b/torch_npu/csrc/distributed/ParallelTcpServer.hpp @@ -100,7 +100,7 @@ using CallBackFn = std::function; class ParallelTcpServer { public: explicit ParallelTcpServer(uint32_t threadNum, uint16_t port, ServerProcFn process) noexcept; - explicit ParallelTcpServer(uint32_t threadNum, const std::string& socketPath, CallBackFn callback) noexcept; + explicit ParallelTcpServer(uint32_t threadNum, const std::string& localSocketPath, CallBackFn callback) noexcept; int Start() noexcept; int LocalStart() noexcept; @@ -119,7 +119,7 @@ public: private: static int CreateSocket(uint16_t port) noexcept; - static int CreateLocalSocket(const std::string& socketPath) noexcept; + static int CreateLocalSocket(const std::string& localSocketPath) noexcept; static int CreateEpoll(int targetFd = -1) noexcept; @@ -136,7 +136,7 @@ private: private: const uint32_t threadNum_; const std::uint16_t port_; - const std::string socketPath_; + const std::string localSocketPath_; const ServerProcFn process_; const CallBackFn callback_; int listenSocket_{ -1 }; diff --git a/torch_npu/csrc/distributed/ParallelTcpStore.cpp b/torch_npu/csrc/distributed/ParallelTcpStore.cpp index 41d9c6fd84..deb8e5face 100644 --- a/torch_npu/csrc/distributed/ParallelTcpStore.cpp +++ b/torch_npu/csrc/distributed/ParallelTcpStore.cpp @@ -42,15 +42,15 @@ ParallelStoreServer::ParallelStoreServer(std::string initKey, uint16_t port, c10 } } -ParallelStoreServer::ParallelStoreServer(std::string 
socketPath, CallBackFn callback) - : socketPath_(std::move(socketPath)), callback_(std::move(callback)) +ParallelStoreServer::ParallelStoreServer(std::string localSocketPath, CallBackFn callback) + : localSocketPath_(std::move(localSocketPath)), callback_(std::move(callback)) { auto threadNum = 4U; - server_ = std::make_unique(threadNum, socketPath_, + server_ = std::make_unique(threadNum, localSocketPath_, [this](const pta::StoreMessage &request) { return this->callback_(request); }); if (server_->Start() != 0) { throw std::runtime_error{ - std::string("start local server on socket ").append(socketPath_).append(" failed.") + std::string("start local server on socket ").append(localSocketPath_).append(" failed.") }; } } @@ -275,17 +275,15 @@ ParallelTcpStore::ParallelTcpStore(const std::string& host, const TCPStoreOption server_ = std::make_shared(initKey_, opts.port, opts.numWorkers); } } - // 检查环境变量 local_rank + char* local_rank_env = std::getenv("LOCAL_RANK"); - if (local_rank_env == nullptr) { - // 如果 LOCAL_RANK 环境变量不存在,则为 Proxy - proxy_ = std::make_unique("/tmp/torch_dist_store", host, opts.port); + if (local_rank_env == nullptr) { + proxy_ = std::make_unique("/tmp/torch_dist_store", host, opts.port); // if LOCAL_RANK not exist,agent->proxy proxy_->Start(); } else { - // 如果 LOCAL_RANK 环境变量存在,则为 Worker - Client_ = std::make_unique("/tmp/torch_dist_store"); - if (Client_.LocalConnect() != 0) { + client_ = std::make_unique("/tmp/torch_dist_store"); // if LOCAL_RANK exist,worker->client + if (client_.LocalConnect() != 0) { throw std::runtime_error{ std::string("connect local client to server failed.") }; } } @@ -302,8 +300,8 @@ ParallelTcpStore::~ParallelTcpStore() noexcept if (proxy_) { proxy_->Stop(); } - if (Client_) { - Client_->LocalClose(); + if (client_) { + client_->LocalClose(); } } @@ -312,7 +310,7 @@ void ParallelTcpStore::set(const std::string &key, const std::vector &v pta::StoreMessage request{ pta::MessageType::SET, key, value }; pta::StoreMessage 
response; std::lock_guard lockGuard{ clientMutex_ }; - auto ret = Client_.SyncCall(request, response); + auto ret = client_.SyncCall(request, response); if (ret != 0) { throw std::runtime_error{ std::string("set key ").append(key).append(" failed.") }; } @@ -421,8 +419,8 @@ int64_t ParallelTcpStore::IncreaseKey(const std::string &key, int64_t value) int ret = -1; if (proxy_) { ret = proxy_->SyncCall(request, response); - } else if (Client_) { - ret = Client_->SyncCall(request, response); + } else if (client_) { + ret = client_->SyncCall(request, response); } else { throw std::runtime_error{ "No valid client available for operation." }; } diff --git a/torch_npu/csrc/distributed/ParallelTcpStore.hpp b/torch_npu/csrc/distributed/ParallelTcpStore.hpp index ca0839baea..e468301cee 100644 --- a/torch_npu/csrc/distributed/ParallelTcpStore.hpp +++ b/torch_npu/csrc/distributed/ParallelTcpStore.hpp @@ -34,7 +34,7 @@ using CallBackFn = std::function; class ParallelStoreServer { public: ParallelStoreServer(std::string initKey, uint16_t port, c10::optional numWorkers); - ParallelStoreServer(std::string socketPath, CallBackFn callback); + ParallelStoreServer(std::string localSocketPath, CallBackFn callback); virtual ~ParallelStoreServer() noexcept; void WaitWorkers(const std::chrono::milliseconds &timeout) noexcept; @@ -53,7 +53,7 @@ private: private: CallBackFn callback_; - std::string socketPath_; + std::string localSocketPath_; using RequestHandler = std::function; std::unique_ptr server_; std::unordered_map requestHandlers_; @@ -95,7 +95,7 @@ private: private: std::unique_ptr proxy_; - std::unique_ptr Client_; + std::unique_ptr client_; std::shared_ptr server_; std::mutex clientMutex_; std::condition_variable initWaitCond_; diff --git a/torch_npu/csrc/distributed/proxy.cpp b/torch_npu/csrc/distributed/proxy.cpp index fc4ba332e8..4b81bb5f62 100644 --- a/torch_npu/csrc/distributed/proxy.cpp +++ b/torch_npu/csrc/distributed/proxy.cpp @@ -6,19 +6,14 @@ namespace c10d { namespace 
pta { -Proxy::Proxy(const std::string& localSocketPath, const std::string& tcpHost, uint16_t tcpPort) +Proxy::Proxy(const std::string& localSocketPath, const std::string& host, uint16_t port) : localServer_(std::make_unique(localSocketPath, [this](const StoreMessage& msg) { return this->HandleLocalServerMessage(msg); } )), - tcpClient_(std::make_unique(tcpHost, tcpPort)) + tcpClient_(std::make_unique(host, port)) { } -Proxy::~Proxy() -{ - Stop(); -} - void Proxy::Start() { if (tcpClient_->Connect() != 0) { diff --git a/torch_npu/csrc/distributed/proxy.hpp b/torch_npu/csrc/distributed/proxy.hpp index fa0f4189a3..6f73082bb6 100644 --- a/torch_npu/csrc/distributed/proxy.hpp +++ b/torch_npu/csrc/distributed/proxy.hpp @@ -10,7 +10,7 @@ namespace pta { class Proxy { public: - Proxy(const std::string& localSocketPath, const std::string& tcpHost, uint16_t tcpPort); + Proxy(const std::string& localSocketPath, const std::string& host, uint16_t port); void Start(); void Stop(); int SyncCall(); -- Gitee From 37f17b36619cc75b618e3c801eac3cc5102e6e4b Mon Sep 17 00:00:00 2001 From: wuyangyu Date: Wed, 21 Aug 2024 14:31:45 +0800 Subject: [PATCH 07/14] change8.21 --- torch_npu/csrc/distributed/Client.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_npu/csrc/distributed/Client.cpp b/torch_npu/csrc/distributed/Client.cpp index 944e10595f..a776c47953 100644 --- a/torch_npu/csrc/distributed/Client.cpp +++ b/torch_npu/csrc/distributed/Client.cpp @@ -120,7 +120,7 @@ int Client::LocalClose() { return 0; } - LOG(ERROR) << "close socket to server(" << path_ << ") failed " << errno << " : " << + LOG(ERROR) << "close socket to server(" << localSocketPath_ << ") failed " << errno << " : " << strerror(errno); return ret; } -- Gitee From daf95bf13abaf2b865951f288813a0aa56c0dd12 Mon Sep 17 00:00:00 2001 From: wu-yangyu2022 Date: Thu, 22 Aug 2024 11:07:41 +0800 Subject: [PATCH 08/14] 8.22 --- torch_npu/csrc/distributed/Client.cpp | 39 ++++++++++--------- 
torch_npu/csrc/distributed/Client.hpp | 6 +-- .../csrc/distributed/ParallelTcpServer.cpp | 3 +- .../csrc/distributed/ParallelTcpStore.cpp | 10 ++--- .../csrc/distributed/ParallelTcpStore.hpp | 4 +- torch_npu/csrc/distributed/proxy.cpp | 9 +++-- torch_npu/csrc/distributed/proxy.hpp | 6 +-- 7 files changed, 40 insertions(+), 37 deletions(-) diff --git a/torch_npu/csrc/distributed/Client.cpp b/torch_npu/csrc/distributed/Client.cpp index a776c47953..84821bd858 100644 --- a/torch_npu/csrc/distributed/Client.cpp +++ b/torch_npu/csrc/distributed/Client.cpp @@ -31,13 +31,14 @@ namespace c10d { namespace pta { static constexpr uint32_t READ_BUF_SZ = 256; -Client::Client(const std::string& localSocketPath) +Client::Client(const std::string localSocketPath) : localSocketPath_(std::move(localSocketPath) ), socketFd_(-1) {} -Client::Client(const std::string& host, uint16_t port) +Client::Client(const std::string host, uint16_t port) : host_{ std::move(host) }, port_{ port }, socketFd_(-1) {} -int Client::Connect() { +int Client::Connect() noexcept +{ socketFd_ = socket(AF_INET, SOCK_STREAM, 0); if (socketFd_ < 0) { LOG(ERROR) << "create tcp client socket failed " << errno << " : " << strerror(errno); @@ -76,7 +77,21 @@ int Client::Connect() { return -1; } -int Client::LocalConnect() { +int Client::Close() noexcept +{ + auto ret = close(socketFd_); + if (ret == 0) { + socketFd_ = -1; + return 0; + } + + LOG(ERROR) << "close socket to server(" << host_ << ":" << port_ << ") failed " << errno << " : " << + strerror(errno); + return ret; +} + +int Client::LocalConnect() noexcept +{ socketFd_ = socket(AF_UNIX, SOCK_STREAM, 0); if (socketFd_ < 0) { LOG(ERROR) << "Create local socket failed: " << strerror(errno); @@ -113,19 +128,7 @@ int Client::LocalConnect() { } -int Client::LocalClose() { - auto ret = close(socketFd_); - if (ret == 0) { - socketFd_ = -1; - return 0; - } - - LOG(ERROR) << "close socket to server(" << localSocketPath_ << ") failed " << errno << " : " << - 
strerror(errno); - return ret; -} - -int Client::Close() noexcept +int Client::LocalClose() noexcept { auto ret = close(socketFd_); if (ret == 0) { @@ -133,7 +136,7 @@ int Client::Close() noexcept return 0; } - LOG(ERROR) << "close socket to server(" << host_ << ":" << port_ << ") failed " << errno << " : " << + LOG(ERROR) << "close socket to server(" << localSocketPath_ << ") failed " << errno << " : " << strerror(errno); return ret; } diff --git a/torch_npu/csrc/distributed/Client.hpp b/torch_npu/csrc/distributed/Client.hpp index 707e99fab7..cc8bff3ae9 100644 --- a/torch_npu/csrc/distributed/Client.hpp +++ b/torch_npu/csrc/distributed/Client.hpp @@ -25,11 +25,11 @@ namespace c10d { namespace pta { class Client { public: - Client(const std::string& localSocketPath); // for local client - Client(const std::string& host, uint16_t port); // for tcp client + Client(const std::string localSocketPath); // for local client + Client(const std::string host, uint16_t port); // for tcp client int Connect() noexcept; - int LocalConnect() noexcept; int Close() noexcept; + int LocalConnect() noexcept; int LocalClose() noexcept; int SyncCall(const StoreMessage &request, StoreMessage &response) noexcept; int SetReceiveTimeout(const std::chrono::milliseconds &value) const noexcept; diff --git a/torch_npu/csrc/distributed/ParallelTcpServer.cpp b/torch_npu/csrc/distributed/ParallelTcpServer.cpp index 1014f3c309..fb62604aea 100644 --- a/torch_npu/csrc/distributed/ParallelTcpServer.cpp +++ b/torch_npu/csrc/distributed/ParallelTcpServer.cpp @@ -15,6 +15,7 @@ */ #include #include +#include #include #include #include @@ -106,7 +107,7 @@ ParallelTcpServer::ParallelTcpServer(uint32_t threadNum, uint16_t port, ServerPr : threadNum_{ std::max(4U, threadNum) }, port_{ port }, process_{ std::move(process) } {} -ParallelTcpServer::ParallelTcpServer(uint32_t threadNum, const std::string& localSocketPath, CallBackFn callback) noexcept +ParallelTcpServer::ParallelTcpServer(uint32_t threadNum, 
const std::string localSocketPath, CallBackFn callback) noexcept : threadNum_{ std::max(4U, threadNum) }, localSocketPath_{ localSocketPath }, callback_{ std::move(callback) } { isLocalServer_ = true; diff --git a/torch_npu/csrc/distributed/ParallelTcpStore.cpp b/torch_npu/csrc/distributed/ParallelTcpStore.cpp index deb8e5face..34395982c4 100644 --- a/torch_npu/csrc/distributed/ParallelTcpStore.cpp +++ b/torch_npu/csrc/distributed/ParallelTcpStore.cpp @@ -282,7 +282,7 @@ ParallelTcpStore::ParallelTcpStore(const std::string& host, const TCPStoreOption proxy_ = std::make_unique("/tmp/torch_dist_store", host, opts.port); // if LOCAL_RANK not exist,agent->proxy proxy_->Start(); } else { - client_ = std::make_unique("/tmp/torch_dist_store"); // if LOCAL_RANK exist,worker->client + client_("/tmp/torch_dist_store"); // if LOCAL_RANK exist,worker->client if (client_.LocalConnect() != 0) { throw std::runtime_error{ std::string("connect local client to server failed.") }; } @@ -300,8 +300,8 @@ ParallelTcpStore::~ParallelTcpStore() noexcept if (proxy_) { proxy_->Stop(); } - if (client_) { - client_->LocalClose(); + else { + client_.LocalClose(); } } @@ -419,10 +419,8 @@ int64_t ParallelTcpStore::IncreaseKey(const std::string &key, int64_t value) int ret = -1; if (proxy_) { ret = proxy_->SyncCall(request, response); - } else if (client_) { - ret = client_->SyncCall(request, response); } else { - throw std::runtime_error{ "No valid client available for operation." 
}; + ret = client_.SyncCall(request, response); } if (ret != 0) { diff --git a/torch_npu/csrc/distributed/ParallelTcpStore.hpp b/torch_npu/csrc/distributed/ParallelTcpStore.hpp index e468301cee..896b95d97e 100644 --- a/torch_npu/csrc/distributed/ParallelTcpStore.hpp +++ b/torch_npu/csrc/distributed/ParallelTcpStore.hpp @@ -93,9 +93,9 @@ private: static std::shared_ptr GetSharedServer(const std::string &initKey, uint16_t port, c10::optional numWorkers); -private: +private: + pta::Client client_; std::unique_ptr proxy_; - std::unique_ptr client_; std::shared_ptr server_; std::mutex clientMutex_; std::condition_variable initWaitCond_; diff --git a/torch_npu/csrc/distributed/proxy.cpp b/torch_npu/csrc/distributed/proxy.cpp index 4b81bb5f62..ecede119a7 100644 --- a/torch_npu/csrc/distributed/proxy.cpp +++ b/torch_npu/csrc/distributed/proxy.cpp @@ -1,16 +1,17 @@ #include "Proxy.hpp" -#include "ParallelStoreServer.hpp" +#include "ParallelTcpStorer.hpp" +#include "Client.hpp" #include "c10/util/Exception.h" #include namespace c10d { namespace pta { -Proxy::Proxy(const std::string& localSocketPath, const std::string& host, uint16_t port) - : localServer_(std::make_unique(localSocketPath, +Proxy::Proxy(const std::string localSocketPath, const std::string host, uint16_t port) + : localServer_(std::make_unique(localSocketPath, [this](const StoreMessage& msg) { return this->HandleLocalServerMessage(msg); } )), - tcpClient_(std::make_unique(host, port)) + tcpClient_(std::make_unique(host, port)) { } diff --git a/torch_npu/csrc/distributed/proxy.hpp b/torch_npu/csrc/distributed/proxy.hpp index 6f73082bb6..11c498a8f6 100644 --- a/torch_npu/csrc/distributed/proxy.hpp +++ b/torch_npu/csrc/distributed/proxy.hpp @@ -10,7 +10,7 @@ namespace pta { class Proxy { public: - Proxy(const std::string& localSocketPath, const std::string& host, uint16_t port); + Proxy(const std::string localSocketPath, const std::string host, uint16_t port); void Start(); void Stop(); int SyncCall(); @@ -18,8 
+18,8 @@ public: private: void HandleLocalServerMessage(const StoreMessage& message); - std::unique_ptr localServer_; - std::unique_ptr tcpClient_; + std::unique_ptr localServer_; + std::unique_ptr tcpClient_; }; } // namespace pta -- Gitee From 4b207a469d235a3f233933f2e4cea1d5d8001ef0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=90=B4=E9=98=B3=E5=AE=87?= Date: Thu, 22 Aug 2024 03:10:40 +0000 Subject: [PATCH 09/14] =?UTF-8?q?=E9=87=8D=E5=91=BD=E5=90=8D=20torch=5Fnpu?= =?UTF-8?q?/csrc/distributed/proxy.cpp=20=E4=B8=BA=20torch=5Fnpu/csrc/dist?= =?UTF-8?q?ributed/Proxy.cpp?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- torch_npu/csrc/distributed/{proxy.cpp => Proxy.cpp} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename torch_npu/csrc/distributed/{proxy.cpp => Proxy.cpp} (100%) diff --git a/torch_npu/csrc/distributed/proxy.cpp b/torch_npu/csrc/distributed/Proxy.cpp similarity index 100% rename from torch_npu/csrc/distributed/proxy.cpp rename to torch_npu/csrc/distributed/Proxy.cpp -- Gitee From e6056427ff3eba36465bffeedc556bffb635c579 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=90=B4=E9=98=B3=E5=AE=87?= Date: Thu, 22 Aug 2024 03:10:52 +0000 Subject: [PATCH 10/14] =?UTF-8?q?=E9=87=8D=E5=91=BD=E5=90=8D=20torch=5Fnpu?= =?UTF-8?q?/csrc/distributed/proxy.hpp=20=E4=B8=BA=20torch=5Fnpu/csrc/dist?= =?UTF-8?q?ributed/Proxy.hpp?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- torch_npu/csrc/distributed/{proxy.hpp => Proxy.hpp} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename torch_npu/csrc/distributed/{proxy.hpp => Proxy.hpp} (100%) diff --git a/torch_npu/csrc/distributed/proxy.hpp b/torch_npu/csrc/distributed/Proxy.hpp similarity index 100% rename from torch_npu/csrc/distributed/proxy.hpp rename to torch_npu/csrc/distributed/Proxy.hpp -- Gitee From 4c9152e55259c5deee9ae66aef6f17388a9b9c9e Mon Sep 17 00:00:00 2001 From: wuyangyu Date: Thu, 22 Aug 2024 19:31:06 
+0800 Subject: [PATCH 11/14] 8.22debug --- torch_npu/csrc/distributed/Client.cpp | 25 +++++++++++++++---- torch_npu/csrc/distributed/Client.hpp | 7 ++++-- .../csrc/distributed/ParallelTcpServer.hpp | 10 ++++---- .../csrc/distributed/ParallelTcpStore.cpp | 7 +++--- .../csrc/distributed/ParallelTcpStore.hpp | 4 ++- torch_npu/csrc/distributed/Proxy.cpp | 3 ++- torch_npu/csrc/distributed/Proxy.hpp | 4 ++- 7 files changed, 42 insertions(+), 18 deletions(-) diff --git a/torch_npu/csrc/distributed/Client.cpp b/torch_npu/csrc/distributed/Client.cpp index 84821bd858..3d2fd30e40 100644 --- a/torch_npu/csrc/distributed/Client.cpp +++ b/torch_npu/csrc/distributed/Client.cpp @@ -26,16 +26,31 @@ #include "c10/util/Logging.h" #include "Client.hpp" +#include "StoreMessagePacker.hpp" namespace c10d { namespace pta { static constexpr uint32_t READ_BUF_SZ = 256; -Client::Client(const std::string localSocketPath) - : localSocketPath_(std::move(localSocketPath) ), socketFd_(-1) {} - -Client::Client(const std::string host, uint16_t port) - : host_{ std::move(host) }, port_{ port }, socketFd_(-1) {} +Client::Client(const std::string localSocketPath) noexcept + : host_{ "" }, port_{ 0 }, localSocketPath_{ std::move(localSocketPath) }, socketFd_(-1) {} + +Client::Client(const std::string host, uint16_t port) noexcept + : host_{ std::move(host) }, port_{ port }, localSocketPath_{""}, socketFd_(-1) {} +// Client::Client() noexcept +// : host_{ "" }, port_{ 0 }, localSocketPath_{ "" }, socketFd(-1) +// {} + +// void Client::init(const std::string localSocketPath) noexcept +// { +// localSocketPath_ = localSocketPath; +// } + +// void Client::init(const std::string host, uint16_t port) noexcept +// { +// host_ = host; +// port_ = port; +// } int Client::Connect() noexcept { diff --git a/torch_npu/csrc/distributed/Client.hpp b/torch_npu/csrc/distributed/Client.hpp index cc8bff3ae9..a12c9a8546 100644 --- a/torch_npu/csrc/distributed/Client.hpp +++ b/torch_npu/csrc/distributed/Client.hpp @@ -25,8 
+25,11 @@ namespace c10d { namespace pta { class Client { public: - Client(const std::string localSocketPath); // for local client - Client(const std::string host, uint16_t port); // for tcp client + Client(const std::string localSocketPath) noexcept; // for local client + Client(const std::string host, uint16_t port) noexcept; // for tcp client + // Client() noexcept; + // void init(const std::string localSocketPath) noexcept; + // void init(const std::string host, uint16_t port) noexcept; int Connect() noexcept; int Close() noexcept; int LocalConnect() noexcept; diff --git a/torch_npu/csrc/distributed/ParallelTcpServer.hpp b/torch_npu/csrc/distributed/ParallelTcpServer.hpp index 2723e828c3..23b050eb49 100644 --- a/torch_npu/csrc/distributed/ParallelTcpServer.hpp +++ b/torch_npu/csrc/distributed/ParallelTcpServer.hpp @@ -134,11 +134,11 @@ private: static int SetNonBlocking(int fd) noexcept; private: - const uint32_t threadNum_; - const std::uint16_t port_; - const std::string localSocketPath_; - const ServerProcFn process_; - const CallBackFn callback_; + const uint32_t threadNum_{ 0 }; + const std::uint16_t port_{ 0 }; + const std::string localSocketPath_{}; + const ServerProcFn process_{}; + const CallBackFn callback_{}; int listenSocket_{ -1 }; int epCtlFd_{ -1 }; bool isLocalServer_{ false }; diff --git a/torch_npu/csrc/distributed/ParallelTcpStore.cpp b/torch_npu/csrc/distributed/ParallelTcpStore.cpp index 34395982c4..bc70b7a20c 100644 --- a/torch_npu/csrc/distributed/ParallelTcpStore.cpp +++ b/torch_npu/csrc/distributed/ParallelTcpStore.cpp @@ -15,6 +15,7 @@ */ #include "ParallelTcpServer.hpp" #include "ParallelTcpStore.hpp" +#include "Proxy.hpp" namespace c10d { namespace pta { @@ -265,7 +266,7 @@ bool ParallelStoreServer::CheckAllKeysExistInLock(const std::vector std::mutex ParallelTcpStore::cacheServerMutex_; std::unordered_map> ParallelTcpStore::cachedServers_; -ParallelTcpStore::ParallelTcpStore(const std::string& host, const TCPStoreOptions& opts) 
+ParallelTcpStore::ParallelTcpStore(const std::string &host, const c10d::TCPStoreOptions &opts) : Store(opts.timeout) { if (opts.isServer) { @@ -279,10 +280,10 @@ ParallelTcpStore::ParallelTcpStore(const std::string& host, const TCPStoreOption char* local_rank_env = std::getenv("LOCAL_RANK"); if (local_rank_env == nullptr) { - proxy_ = std::make_unique("/tmp/torch_dist_store", host, opts.port); // if LOCAL_RANK not exist,agent->proxy + proxy_ = std::make_unique("/mnt/torch_dist_store", host, opts.port); // if LOCAL_RANK not exist,agent->proxy proxy_->Start(); } else { - client_("/tmp/torch_dist_store"); // if LOCAL_RANK exist,worker->client + client_("/mnt/torch_dist_store"); // if LOCAL_RANK exist,worker->client if (client_.LocalConnect() != 0) { throw std::runtime_error{ std::string("connect local client to server failed.") }; } diff --git a/torch_npu/csrc/distributed/ParallelTcpStore.hpp b/torch_npu/csrc/distributed/ParallelTcpStore.hpp index 896b95d97e..5a65d6d0ff 100644 --- a/torch_npu/csrc/distributed/ParallelTcpStore.hpp +++ b/torch_npu/csrc/distributed/ParallelTcpStore.hpp @@ -26,10 +26,12 @@ #include "c10d/TCPStore.hpp" #include "Client.hpp" -#include "Proxy.hpp" #include "ParallelTcpServer.hpp" namespace c10d { namespace pta { + +class Proxy; + using CallBackFn = std::function; class ParallelStoreServer { public: diff --git a/torch_npu/csrc/distributed/Proxy.cpp b/torch_npu/csrc/distributed/Proxy.cpp index ecede119a7..63fb7f1ab9 100644 --- a/torch_npu/csrc/distributed/Proxy.cpp +++ b/torch_npu/csrc/distributed/Proxy.cpp @@ -1,8 +1,9 @@ #include "Proxy.hpp" -#include "ParallelTcpStorer.hpp" +#include "ParallelTcpStore.hpp" #include "Client.hpp" #include "c10/util/Exception.h" #include +#include "StoreMessagePacker.hpp" namespace c10d { namespace pta { diff --git a/torch_npu/csrc/distributed/Proxy.hpp b/torch_npu/csrc/distributed/Proxy.hpp index 11c498a8f6..93a5157e46 100644 --- a/torch_npu/csrc/distributed/Proxy.hpp +++ 
b/torch_npu/csrc/distributed/Proxy.hpp @@ -2,12 +2,14 @@ #include #include -#include "ParallelTcpStore.hpp" #include "Client.hpp" +#include "StoreMessagePacker.hpp" namespace c10d { namespace pta { +class ParallelStoreServer; + class Proxy { public: Proxy(const std::string localSocketPath, const std::string host, uint16_t port); -- Gitee From 25aa3f97735edf217aa9347648835278482a866c Mon Sep 17 00:00:00 2001 From: wuyangyu Date: Tue, 27 Aug 2024 17:02:18 +0800 Subject: [PATCH 12/14] debug --- torch_npu/csrc/distributed/Client.cpp | 25 ++++++------------- torch_npu/csrc/distributed/Client.hpp | 12 ++++----- .../csrc/distributed/ParallelTcpServer.hpp | 6 ++--- .../csrc/distributed/ParallelTcpStore.cpp | 7 +++--- torch_npu/csrc/distributed/Proxy.cpp | 14 +++++------ torch_npu/csrc/distributed/Proxy.hpp | 9 ++++--- 6 files changed, 31 insertions(+), 42 deletions(-) diff --git a/torch_npu/csrc/distributed/Client.cpp b/torch_npu/csrc/distributed/Client.cpp index 3d2fd30e40..f9a3709d27 100644 --- a/torch_npu/csrc/distributed/Client.cpp +++ b/torch_npu/csrc/distributed/Client.cpp @@ -32,25 +32,14 @@ namespace c10d { namespace pta { static constexpr uint32_t READ_BUF_SZ = 256; -Client::Client(const std::string localSocketPath) noexcept - : host_{ "" }, port_{ 0 }, localSocketPath_{ std::move(localSocketPath) }, socketFd_(-1) {} - Client::Client(const std::string host, uint16_t port) noexcept - : host_{ std::move(host) }, port_{ port }, localSocketPath_{""}, socketFd_(-1) {} -// Client::Client() noexcept -// : host_{ "" }, port_{ 0 }, localSocketPath_{ "" }, socketFd(-1) -// {} - -// void Client::init(const std::string localSocketPath) noexcept -// { -// localSocketPath_ = localSocketPath; -// } - -// void Client::init(const std::string host, uint16_t port) noexcept -// { -// host_ = host; -// port_ = port; -// } + : host_{ std::move(host) }, port_{ port }, socketFd_(-1) {} +Client::Client() noexcept {} + +void Client::init(std::string localSocketPath) noexcept +{ + 
localSocketPath_ = localSocketPath; +} int Client::Connect() noexcept { diff --git a/torch_npu/csrc/distributed/Client.hpp b/torch_npu/csrc/distributed/Client.hpp index a12c9a8546..34b598e2fd 100644 --- a/torch_npu/csrc/distributed/Client.hpp +++ b/torch_npu/csrc/distributed/Client.hpp @@ -25,11 +25,9 @@ namespace c10d { namespace pta { class Client { public: - Client(const std::string localSocketPath) noexcept; // for local client Client(const std::string host, uint16_t port) noexcept; // for tcp client - // Client() noexcept; - // void init(const std::string localSocketPath) noexcept; - // void init(const std::string host, uint16_t port) noexcept; + Client() noexcept; + void init(std::string localSocketPath) noexcept; int Connect() noexcept; int Close() noexcept; int LocalConnect() noexcept; @@ -38,9 +36,9 @@ public: int SetReceiveTimeout(const std::chrono::milliseconds &value) const noexcept; private: - const std::string localSocketPath_; - const std::string host_; - const uint16_t port_; + std::string localSocketPath_{}; + const std::string host_{}; + const uint16_t port_{0}; int socketFd_; }; } // pta diff --git a/torch_npu/csrc/distributed/ParallelTcpServer.hpp b/torch_npu/csrc/distributed/ParallelTcpServer.hpp index 23b050eb49..f17a5281ff 100644 --- a/torch_npu/csrc/distributed/ParallelTcpServer.hpp +++ b/torch_npu/csrc/distributed/ParallelTcpServer.hpp @@ -100,7 +100,7 @@ using CallBackFn = std::function; class ParallelTcpServer { public: explicit ParallelTcpServer(uint32_t threadNum, uint16_t port, ServerProcFn process) noexcept; - explicit ParallelTcpServer(uint32_t threadNum, const std::string& localSocketPath, CallBackFn callback) noexcept; + explicit ParallelTcpServer(uint32_t threadNum, const std::string localSocketPath, CallBackFn callback) noexcept; int Start() noexcept; int LocalStart() noexcept; @@ -137,8 +137,8 @@ private: const uint32_t threadNum_{ 0 }; const std::uint16_t port_{ 0 }; const std::string localSocketPath_{}; - const ServerProcFn 
process_{}; - const CallBackFn callback_{}; + const ServerProcFn process_{ nullptr }; + const CallBackFn callback_{ nullptr }; int listenSocket_{ -1 }; int epCtlFd_{ -1 }; bool isLocalServer_{ false }; diff --git a/torch_npu/csrc/distributed/ParallelTcpStore.cpp b/torch_npu/csrc/distributed/ParallelTcpStore.cpp index bc70b7a20c..1e8721b2ce 100644 --- a/torch_npu/csrc/distributed/ParallelTcpStore.cpp +++ b/torch_npu/csrc/distributed/ParallelTcpStore.cpp @@ -16,6 +16,7 @@ #include "ParallelTcpServer.hpp" #include "ParallelTcpStore.hpp" #include "Proxy.hpp" +#include "Client.hpp" namespace c10d { namespace pta { @@ -283,10 +284,10 @@ ParallelTcpStore::ParallelTcpStore(const std::string &host, const c10d::TCPStore proxy_ = std::make_unique("/mnt/torch_dist_store", host, opts.port); // if LOCAL_RANK not exist,agent->proxy proxy_->Start(); } else { - client_("/mnt/torch_dist_store"); // if LOCAL_RANK exist,worker->client + client_.init("/mnt/torch_dist_store"); // if LOCAL_RANK exist,worker->client if (client_.LocalConnect() != 0) { - throw std::runtime_error{ std::string("connect local client to server failed.") }; - } + throw std::runtime_error{ std::string("connect local client to server failed.") }; + } } if (opts.waitWorkers) { IncreaseKey(initKey_, 1); diff --git a/torch_npu/csrc/distributed/Proxy.cpp b/torch_npu/csrc/distributed/Proxy.cpp index 63fb7f1ab9..d8fb06f23a 100644 --- a/torch_npu/csrc/distributed/Proxy.cpp +++ b/torch_npu/csrc/distributed/Proxy.cpp @@ -9,30 +9,30 @@ namespace c10d { namespace pta { Proxy::Proxy(const std::string localSocketPath, const std::string host, uint16_t port) - : localServer_(std::make_unique(localSocketPath, + : localServer_{std::make_unique(localSocketPath, [this](const StoreMessage& msg) { return this->HandleLocalServerMessage(msg); } - )), - tcpClient_(std::make_unique(host, port)) + )}, + tcpClient_{std::make_unique(host, port)} { } -void Proxy::Start() +void Proxy::Start() noexcept { if (tcpClient_->Connect() != 0) { throw 
std::runtime_error("Failed to connect to TCP server"); } } -void Proxy::Stop() +void Proxy::Stop() noexcept { tcpClient_->Close(); } -int Proxy::SyncCall(const pta::StoreMessage& request, pta::StoreMessage& response) { +int Proxy::SyncCall(const pta::StoreMessage& request, pta::StoreMessage& response) noexcept { return tcpClient_->SyncCall(request, response); } -StoreMessage Proxy::HandleLocalServerMessage(const StoreMessage& message) +StoreMessage Proxy::HandleLocalServerMessage(const StoreMessage& message) noexcept { StoreMessage response; if (tcpClient_->SyncCall(message, response) != 0) { diff --git a/torch_npu/csrc/distributed/Proxy.hpp b/torch_npu/csrc/distributed/Proxy.hpp index 93a5157e46..671ebb68a8 100644 --- a/torch_npu/csrc/distributed/Proxy.hpp +++ b/torch_npu/csrc/distributed/Proxy.hpp @@ -12,10 +12,11 @@ class ParallelStoreServer; class Proxy { public: - Proxy(const std::string localSocketPath, const std::string host, uint16_t port); - void Start(); - void Stop(); - int SyncCall(); + Proxy( std::string localSocketPath, const std::string host, uint16_t port); + void Start() noexcept; + void Stop() noexcept; + int SyncCall(const pta::StoreMessage& request, pta::StoreMessage& response) noexcept; + StoreMessage HandleLocalServerMessage(const StoreMessage& message) noexcept; private: void HandleLocalServerMessage(const StoreMessage& message); -- Gitee From a9279b4a5a66351d73dbe177c3062a274808ef15 Mon Sep 17 00:00:00 2001 From: wuyangyu Date: Wed, 4 Sep 2024 16:12:37 +0800 Subject: [PATCH 13/14] clean_code --- torch_npu/csrc/distributed/Init.cpp | 5 +- .../{Proxy.cpp => ParallelStoreProxy.cpp} | 21 +- .../{Proxy.hpp => ParallelStoreProxy.hpp} | 12 +- .../csrc/distributed/ParallelTcpServer.cpp | 28 +- .../csrc/distributed/ParallelTcpServer.hpp | 17 +- .../csrc/distributed/ParallelTcpStore.cpp | 273 +++++++++++------- .../csrc/distributed/ParallelTcpStore.hpp | 16 +- .../{Client.cpp => StoreClient.cpp} | 26 +- .../{Client.hpp => StoreClient.hpp} | 7 +- 
torch_npu/distributed/rendezvous.py | 15 +- 10 files changed, 236 insertions(+), 184 deletions(-) rename torch_npu/csrc/distributed/{Proxy.cpp => ParallelStoreProxy.cpp} (60%) rename torch_npu/csrc/distributed/{Proxy.hpp => ParallelStoreProxy.hpp} (54%) rename torch_npu/csrc/distributed/{Client.cpp => StoreClient.cpp} (90%) rename torch_npu/csrc/distributed/{Client.hpp => StoreClient.hpp} (91%) diff --git a/torch_npu/csrc/distributed/Init.cpp b/torch_npu/csrc/distributed/Init.cpp index f4500b828b..a75767f0c4 100644 --- a/torch_npu/csrc/distributed/Init.cpp +++ b/torch_npu/csrc/distributed/Init.cpp @@ -476,6 +476,7 @@ Example:: .def(py::init([](const std::string &host, uint16_t port, int worldSize, + int16_t localRank, bool isServer, std::chrono::milliseconds timeout, bool waitWorkers, @@ -486,8 +487,8 @@ Example:: } ::c10d::TCPStoreOptions opts{ port, isServer, numWorkers, waitWorkers, timeout, multiTenant }; - return c10::make_intrusive <::c10d::ParallelTcpStore>(host, opts); - }), py::arg("host") = "127.0.0.1", py::arg("port") = 29500, py::arg("world_size") = -1, + return c10::make_intrusive <::c10d::ParallelTcpStore>(host, localRank, opts); + }), py::arg("host") = "127.0.0.1", py::arg("port") = 29500, py::arg("world_size") = -1, py::arg("localRank") = -2, py::arg("is_server") = false, py::arg("timeout") = std::chrono::milliseconds(300000), py::arg("wait_workers") = true, py::arg("multi_tenant") = false); diff --git a/torch_npu/csrc/distributed/Proxy.cpp b/torch_npu/csrc/distributed/ParallelStoreProxy.cpp similarity index 60% rename from torch_npu/csrc/distributed/Proxy.cpp rename to torch_npu/csrc/distributed/ParallelStoreProxy.cpp index d8fb06f23a..647120de44 100644 --- a/torch_npu/csrc/distributed/Proxy.cpp +++ b/torch_npu/csrc/distributed/ParallelStoreProxy.cpp @@ -1,20 +1,17 @@ -#include "Proxy.hpp" +#include "ParallelStoreProxy.hpp" #include "ParallelTcpStore.hpp" -#include "Client.hpp" +#include "StoreClient.hpp" #include "c10/util/Exception.h" #include 
#include "StoreMessagePacker.hpp" namespace c10d { namespace pta { - Proxy::Proxy(const std::string localSocketPath, const std::string host, uint16_t port) - : localServer_{std::make_unique(localSocketPath, - [this](const StoreMessage& msg) { return this->HandleLocalServerMessage(msg); } - )}, - tcpClient_{std::make_unique(host, port)} -{ -} + : localServer_{ std::make_unique(localSocketPath, + [this](const StoreMessage &msg) { return this->HandleLocalServerMessage(msg); }) }, + tcpClient_{ std::make_unique(host, port) } +{} void Proxy::Start() noexcept { @@ -28,11 +25,12 @@ void Proxy::Stop() noexcept tcpClient_->Close(); } -int Proxy::SyncCall(const pta::StoreMessage& request, pta::StoreMessage& response) noexcept { +int Proxy::SyncCall(const pta::StoreMessage &request, pta::StoreMessage &response) noexcept +{ return tcpClient_->SyncCall(request, response); } -StoreMessage Proxy::HandleLocalServerMessage(const StoreMessage& message) noexcept +StoreMessage Proxy::HandleLocalServerMessage(const StoreMessage &message) noexcept { StoreMessage response; if (tcpClient_->SyncCall(message, response) != 0) { @@ -40,6 +38,5 @@ StoreMessage Proxy::HandleLocalServerMessage(const StoreMessage& message) noexce } return response; } - } // namespace pta } // namespace c10d \ No newline at end of file diff --git a/torch_npu/csrc/distributed/Proxy.hpp b/torch_npu/csrc/distributed/ParallelStoreProxy.hpp similarity index 54% rename from torch_npu/csrc/distributed/Proxy.hpp rename to torch_npu/csrc/distributed/ParallelStoreProxy.hpp index 671ebb68a8..c1fa0a5b85 100644 --- a/torch_npu/csrc/distributed/Proxy.hpp +++ b/torch_npu/csrc/distributed/ParallelStoreProxy.hpp @@ -2,28 +2,24 @@ #include #include -#include "Client.hpp" +#include "StoreClient.hpp" #include "StoreMessagePacker.hpp" namespace c10d { namespace pta { - class ParallelStoreServer; class Proxy { public: - Proxy( std::string localSocketPath, const std::string host, uint16_t port); + Proxy(std::string localSocketPath, 
const std::string host, uint16_t port); void Start() noexcept; void Stop() noexcept; - int SyncCall(const pta::StoreMessage& request, pta::StoreMessage& response) noexcept; - StoreMessage HandleLocalServerMessage(const StoreMessage& message) noexcept; + int SyncCall(const pta::StoreMessage &request, pta::StoreMessage &response) noexcept; + StoreMessage HandleLocalServerMessage(const StoreMessage &message) noexcept; private: - void HandleLocalServerMessage(const StoreMessage& message); - std::unique_ptr localServer_; std::unique_ptr tcpClient_; }; - } // namespace pta } // namespace c10d \ No newline at end of file diff --git a/torch_npu/csrc/distributed/ParallelTcpServer.cpp b/torch_npu/csrc/distributed/ParallelTcpServer.cpp index fb62604aea..8203c10a9b 100644 --- a/torch_npu/csrc/distributed/ParallelTcpServer.cpp +++ b/torch_npu/csrc/distributed/ParallelTcpServer.cpp @@ -1,4 +1,4 @@ -/** +/* * * @copyright Copyright (c) 2024 Huawei Technologies Co., Ltd. All rights reserved. * * Licensed under the BSD 3-Clause License (the "License"); @@ -107,8 +107,9 @@ ParallelTcpServer::ParallelTcpServer(uint32_t threadNum, uint16_t port, ServerPr : threadNum_{ std::max(4U, threadNum) }, port_{ port }, process_{ std::move(process) } {} -ParallelTcpServer::ParallelTcpServer(uint32_t threadNum, const std::string localSocketPath, CallBackFn callback) noexcept - : threadNum_{ std::max(4U, threadNum) }, localSocketPath_{ localSocketPath }, callback_{ std::move(callback) } +ParallelTcpServer::ParallelTcpServer(uint32_t threadNum, const std::string localSocketPath, + ServerProcFn process) noexcept + : threadNum_{ std::max(4U, threadNum) }, localSocketPath_{ localSocketPath }, process_{ std::move(process) } { isLocalServer_ = true; } @@ -121,11 +122,11 @@ int ParallelTcpServer::Start() noexcept return -1; } - if(isLocalServer_){ - listenSocket_ = CreateLocalSocket(localSocketPath_); - }else{ + if (isLocalServer_) { + listenSocket_ = CreateLocalSocket(localSocketPath_); + } else { 
listenSocket_ = CreateSocket(port_); - } + } if (listenSocket_ < 0) { delete[] buffer_; buffer_ = nullptr; @@ -207,8 +208,8 @@ void ParallelTcpServer::WakeupWaitingClients(const std::string &key) noexcept keyWaitingSockets_.erase(key); lockGuard.unlock(); - std::vector body{static_cast(MessageWaitKeyRes::KEYS_STOP_WAITING)}; - StoreMessage response{MessageType::WAIT, body}; + std::vector body{ static_cast(MessageWaitKeyRes::KEYS_STOP_WAITING) }; + StoreMessage response{ MessageType::WAIT, body }; auto buf = StoreMessagePacker::Pack(response); for (auto socket : stopWaitingSockets) { write(socket, buf.data(), buf.size()); @@ -250,11 +251,12 @@ int ParallelTcpServer::CreateSocket(uint16_t port) noexcept return sockFd; } -int ParallelTcpServer::CreateLocalSocket(const std::string& localSocketPath) noexcept +int ParallelTcpServer::CreateLocalSocket(const std::string &localSocketPath) noexcept { struct sockaddr_un servAddr {}; servAddr.sun_family = AF_UNIX; - strncpy(servAddr.sun_path, localSocketPath.c_str(), sizeof(servAddr.sun_path) - 1); + servAddr.sun_path[0] = '\0'; + strncpy(servAddr.sun_path + 1, localSocketPath.c_str(), sizeof(servAddr.sun_path) - 1); auto sockFd = ::socket(AF_UNIX, SOCK_STREAM, 0); if (sockFd < 0) { @@ -262,7 +264,7 @@ int ParallelTcpServer::CreateLocalSocket(const std::string& localSocketPath) noe return -1; } - unlink(localSocketPath.c_str()); // Remove any existing socket file + unlink(localSocketPath.c_str()); // Remove any existing socket file auto ret = ::bind(sockFd, reinterpret_cast(&servAddr), sizeof(servAddr)); if (ret != 0) { @@ -423,7 +425,7 @@ void ParallelTcpServer::ProcessClientEvent(int epFd, int fd, uint32_t event, if (event & EPOLLIN) { pos->second.ReceiveData(); while (pos->second.HasNextReq()) { - auto response = isLocalServer_ ? 
callback_(pos->second.NextRequest()) : process_(fd, pos->second.NextRequest()); + auto response = process_(fd, pos->second.NextRequest()); pos->second.SendResponse(response); } diff --git a/torch_npu/csrc/distributed/ParallelTcpServer.hpp b/torch_npu/csrc/distributed/ParallelTcpServer.hpp index f17a5281ff..aa33b3dc06 100644 --- a/torch_npu/csrc/distributed/ParallelTcpServer.hpp +++ b/torch_npu/csrc/distributed/ParallelTcpServer.hpp @@ -1,4 +1,4 @@ -/** +/* * * @copyright Copyright (c) 2024 Huawei Technologies Co., Ltd. All rights reserved. * * Licensed under the BSD 3-Clause License (the "License"); @@ -27,7 +27,7 @@ namespace c10d { namespace pta { -/** +/* * * @brief wrapper for pthread_spinlock_t */ class SpinLock { @@ -61,7 +61,7 @@ private: pthread_spinlock_t spinlock_{}; }; -/** +/* * * @brief store client IO context for server. */ class ClientIoContext { @@ -92,15 +92,15 @@ private: }; using ServerProcFn = std::function; -using CallBackFn = std::function; +using CallBackFn = std::function; -/** +/* * * @brief epoll based TCP server with registered message processor. 
*/ class ParallelTcpServer { -public: +public: explicit ParallelTcpServer(uint32_t threadNum, uint16_t port, ServerProcFn process) noexcept; - explicit ParallelTcpServer(uint32_t threadNum, const std::string localSocketPath, CallBackFn callback) noexcept; + explicit ParallelTcpServer(uint32_t threadNum, const std::string localSocketPath, ServerProcFn process) noexcept; int Start() noexcept; int LocalStart() noexcept; @@ -119,7 +119,7 @@ public: private: static int CreateSocket(uint16_t port) noexcept; - static int CreateLocalSocket(const std::string& localSocketPath) noexcept; + static int CreateLocalSocket(const std::string &localSocketPath) noexcept; static int CreateEpoll(int targetFd = -1) noexcept; @@ -138,7 +138,6 @@ private: const std::uint16_t port_{ 0 }; const std::string localSocketPath_{}; const ServerProcFn process_{ nullptr }; - const CallBackFn callback_{ nullptr }; int listenSocket_{ -1 }; int epCtlFd_{ -1 }; bool isLocalServer_{ false }; diff --git a/torch_npu/csrc/distributed/ParallelTcpStore.cpp b/torch_npu/csrc/distributed/ParallelTcpStore.cpp index 1e8721b2ce..fe889fae35 100644 --- a/torch_npu/csrc/distributed/ParallelTcpStore.cpp +++ b/torch_npu/csrc/distributed/ParallelTcpStore.cpp @@ -1,4 +1,4 @@ -/** +/* * * @copyright Copyright (c) 2024 Huawei Technologies Co., Ltd. All rights reserved. 
* * Licensed under the BSD 3-Clause License (the "License"); @@ -15,8 +15,8 @@ */ #include "ParallelTcpServer.hpp" #include "ParallelTcpStore.hpp" -#include "Proxy.hpp" -#include "Client.hpp" +#include "ParallelStoreProxy.hpp" +#include "StoreClient.hpp" namespace c10d { namespace pta { @@ -47,9 +47,10 @@ ParallelStoreServer::ParallelStoreServer(std::string initKey, uint16_t port, c10 ParallelStoreServer::ParallelStoreServer(std::string localSocketPath, CallBackFn callback) : localSocketPath_(std::move(localSocketPath)), callback_(std::move(callback)) { - auto threadNum = 4U; + auto threadNum = 1U; + LocalInitializeHandlers(); server_ = std::make_unique(threadNum, localSocketPath_, - [this](const pta::StoreMessage &request) { return this->callback_(request); }); + [this](int fd, const pta::StoreMessage &request) { return ProcessRequest(fd, request); }); if (server_->Start() != 0) { throw std::runtime_error{ std::string("start local server on socket ").append(localSocketPath_).append(" failed.") @@ -258,6 +259,26 @@ void ParallelStoreServer::InitializeHandlers() noexcept [this](int fd, const pta::StoreMessage &req) { return ProcessDeleteRequest(fd, req); }); } +void ParallelStoreServer::LocalInitializeHandlers() noexcept +{ + requestHandlers_.emplace(pta::MessageType::SET, + [this](int fd, const pta::StoreMessage &req) { return callback_(req); }); + requestHandlers_.emplace(pta::MessageType::COMPARE_SET, + [this](int fd, const pta::StoreMessage &req) { return callback_(req); }); + requestHandlers_.emplace(pta::MessageType::GET, + [this](int fd, const pta::StoreMessage &req) { return callback_(req); }); + requestHandlers_.emplace(pta::MessageType::ADD, + [this](int fd, const pta::StoreMessage &req) { return callback_(req); }); + requestHandlers_.emplace(pta::MessageType::CHECK, + [this](int fd, const pta::StoreMessage &req) { return callback_(req); }); + requestHandlers_.emplace(pta::MessageType::WAIT, + [this](int fd, const pta::StoreMessage &req) { return 
callback_(req); }); + requestHandlers_.emplace(pta::MessageType::GET_NUM_KEYS, + [this](int fd, const pta::StoreMessage &req) { return callback_(req); }); + requestHandlers_.emplace(pta::MessageType::DELETE_KEY, + [this](int fd, const pta::StoreMessage &req) { return callback_(req); }); +} + bool ParallelStoreServer::CheckAllKeysExistInLock(const std::vector &keys) noexcept { return std::all_of(keys.begin(), keys.end(), [this](const std::string &key) { return keyStore_.count(key) > 0; }); @@ -267,7 +288,7 @@ bool ParallelStoreServer::CheckAllKeysExistInLock(const std::vector std::mutex ParallelTcpStore::cacheServerMutex_; std::unordered_map> ParallelTcpStore::cachedServers_; -ParallelTcpStore::ParallelTcpStore(const std::string &host, const c10d::TCPStoreOptions &opts) +ParallelTcpStore::ParallelTcpStore(const std::string &host, const int16_t &localRank, const c10d::TCPStoreOptions &opts) : Store(opts.timeout) { if (opts.isServer) { @@ -276,15 +297,13 @@ ParallelTcpStore::ParallelTcpStore(const std::string &host, const c10d::TCPStore } else { server_ = std::make_shared(initKey_, opts.port, opts.numWorkers); } - } + } - char* local_rank_env = std::getenv("LOCAL_RANK"); - - if (local_rank_env == nullptr) { - proxy_ = std::make_unique("/mnt/torch_dist_store", host, opts.port); // if LOCAL_RANK not exist,agent->proxy + if (localRank == -1) { + proxy_ = std::make_unique("/mnt/proxy/torch_dist_store", host, opts.port); proxy_->Start(); } else { - client_.init("/mnt/torch_dist_store"); // if LOCAL_RANK exist,worker->client + client_.init("/mnt/proxy/torch_dist_store"); if (client_.LocalConnect() != 0) { throw std::runtime_error{ std::string("connect local client to server failed.") }; } @@ -301,8 +320,7 @@ ParallelTcpStore::~ParallelTcpStore() noexcept { if (proxy_) { proxy_->Stop(); - } - else { + } else { client_.LocalClose(); } } @@ -312,7 +330,12 @@ void ParallelTcpStore::set(const std::string &key, const std::vector &v pta::StoreMessage request{ 
pta::MessageType::SET, key, value }; pta::StoreMessage response; std::lock_guard lockGuard{ clientMutex_ }; - auto ret = client_.SyncCall(request, response); + int ret = -1; + if (proxy_) { + ret = proxy_->SyncCall(request, response); + } else { + ret = client_.SyncCall(request, response); + } if (ret != 0) { throw std::runtime_error{ std::string("set key ").append(key).append(" failed.") }; } @@ -324,7 +347,12 @@ std::vector ParallelTcpStore::compareSet(const std::string &key, const pta::StoreMessage request{ pta::MessageType::COMPARE_SET, key, currentValue, newValue }; pta::StoreMessage response; std::lock_guard lockGuard{ clientMutex_ }; - auto ret = client_.SyncCall(request, response); + int ret = -1; + if (proxy_) { + ret = proxy_->SyncCall(request, response); + } else { + ret = client_.SyncCall(request, response); + } if (ret != 0) { throw std::runtime_error{ std::string("compare and set key ").append(key).append(" failed.") }; } @@ -340,122 +368,147 @@ std::vector ParallelTcpStore::get(const std::string &key) pta::StoreMessage getResp; std::lock_guard lockGuard{ clientMutex_ }; - auto ret = client_.SyncCall(waitReq, waitResp); - if (ret != 0) { - throw std::runtime_error{ std::string("get key ").append(key).append(" failed.") }; + int ret = -1; + if (proxy_) { + ret = proxy_->SyncCall(waitReq, waitResp); + if (ret != 0) { + throw std::runtime_error{ std::string("proxy_ sync wait msg ").append(" failed.") }; + ret = proxy_->SyncCall(getReq, getResp); + if (ret != 0) { + throw std::runtime_error{ std::string("proxy_ sync get msg ").append(" failed.") }; + } + } else { + ret = client_.SyncCall(waitReq, waitResp); + if (ret != 0) { + throw std::runtime_error{ std::string("get key ").append(key).append(" failed.") }; + } + + ret = client_.SyncCall(getReq, getResp); + if (ret != 0) { + throw std::runtime_error{ std::string("get key ").append(key).append(" failed.") }; + } + } + return getResp.values.empty() ? 
std::vector{} : std::move(getResp.values[0]); } - ret = client_.SyncCall(getReq, getResp); - if (ret != 0) { - throw std::runtime_error{ std::string("get key ").append(key).append(" failed.") }; + int64_t ParallelTcpStore::add(const std::string &key, int64_t value) + { + return IncreaseKey(key, value); } - return getResp.values.empty() ? std::vector{} : std::move(getResp.values[0]); -} - -int64_t ParallelTcpStore::add(const std::string &key, int64_t value) -{ - return IncreaseKey(key, value); -} + bool ParallelTcpStore::deleteKey(const std::string &key) + { + pta::StoreMessage request{ pta::MessageType::DELETE_KEY, key }; + pta::StoreMessage response; + std::lock_guard lockGuard{ clientMutex_ }; + int ret = -1; + if (proxy_) { + ret = proxy_->SyncCall(request, response); + } else { + ret = client_.SyncCall(request, response); + } + if (ret != 0) { + throw std::runtime_error{ std::string("delete key ").append(key).append(" failed.") }; + } -bool ParallelTcpStore::deleteKey(const std::string &key) -{ - pta::StoreMessage request{ pta::MessageType::DELETE_KEY, key }; - pta::StoreMessage response; - std::lock_guard lockGuard{ clientMutex_ }; - auto ret = client_.SyncCall(request, response); - if (ret != 0) { - throw std::runtime_error{ std::string("delete key ").append(key).append(" failed.") }; + return !response.values.empty() && !response.values[0].empty() && response.values[0][0] > 0U; } - return !response.values.empty() && !response.values[0].empty() && response.values[0][0] > 0U; -} + bool ParallelTcpStore::check(const std::vector &keys) + { + throw std::runtime_error("unsupported check operation."); + } -bool ParallelTcpStore::check(const std::vector &keys) -{ - throw std::runtime_error("unsupported check operation."); -} + int64_t ParallelTcpStore::getNumKeys() + { + pta::StoreMessage request{ pta::MessageType::GET_NUM_KEYS }; + pta::StoreMessage response; + std::lock_guard lockGuard{ clientMutex_ }; + int ret = -1; + if (proxy_) { + ret = 
proxy_->SyncCall(request, response); + } else { + ret = client_.SyncCall(request, response); + } + if (ret != 0) { + throw std::runtime_error{ "get number keys failed." }; + } -int64_t ParallelTcpStore::getNumKeys() -{ - pta::StoreMessage request{ pta::MessageType::GET_NUM_KEYS }; - pta::StoreMessage response; - std::lock_guard lockGuard{ clientMutex_ }; - auto ret = client_.SyncCall(request, response); - if (ret != 0) { - throw std::runtime_error{ "get number keys failed." }; + return pta::StoreMessagePacker::UnpackPod(response.values[0]); } - return pta::StoreMessagePacker::UnpackPod(response.values[0]); -} + void ParallelTcpStore::wait(const std::vector &keys) + { + wait(keys, timeout_); + } -void ParallelTcpStore::wait(const std::vector &keys) -{ - wait(keys, timeout_); -} + void ParallelTcpStore::wait(const std::vector &keys, const std::chrono::milliseconds &timeout) + { + pta::StoreMessage request{ pta::MessageType::WAIT, keys }; + pta::StoreMessage response; + client_.SetReceiveTimeout(timeout); + std::lock_guard lockGuard{ clientMutex_ }; + DoWait(request, response); + } -void ParallelTcpStore::wait(const std::vector &keys, const std::chrono::milliseconds &timeout) -{ - pta::StoreMessage request{ pta::MessageType::WAIT, keys }; - pta::StoreMessage response; - client_.SetReceiveTimeout(timeout); - std::lock_guard lockGuard{ clientMutex_ }; - DoWait(request, response); -} + const std::chrono::milliseconds &ParallelTcpStore::getTimeout() const noexcept + { + return timeout_; + } -const std::chrono::milliseconds &ParallelTcpStore::getTimeout() const noexcept -{ - return timeout_; -} + void ParallelTcpStore::setTimeout(const std::chrono::milliseconds &timeout) + { + timeout_ = timeout; + } -void ParallelTcpStore::setTimeout(const std::chrono::milliseconds &timeout) -{ - timeout_ = timeout; -} + int64_t ParallelTcpStore::IncreaseKey(const std::string &key, int64_t value) + { + pta::StoreMessage request{ pta::MessageType::ADD, key, 
pta::StoreMessagePacker::PackPod(value) }; + pta::StoreMessage response; + std::lock_guard lockGuard{ clientMutex_ }; + int ret = -1; + if (proxy_) { + ret = proxy_->SyncCall(request, response); + } else { + ret = client_.SyncCall(request, response); + } -int64_t ParallelTcpStore::IncreaseKey(const std::string &key, int64_t value) -{ - pta::StoreMessage request{ pta::MessageType::ADD, key, pta::StoreMessagePacker::PackPod(value) }; - pta::StoreMessage response; - std::lock_guard lockGuard{ clientMutex_ }; - int ret = -1; - if (proxy_) { - ret = proxy_->SyncCall(request, response); - } else { - ret = client_.SyncCall(request, response); - } + if (ret != 0) { + throw std::runtime_error{ std::string("add key ").append(key).append(" failed.") }; + } - if (ret != 0) { - throw std::runtime_error{ std::string("add key ").append(key).append(" failed.") }; + return pta::StoreMessagePacker::UnpackPod(response.values[0]); } - return pta::StoreMessagePacker::UnpackPod(response.values[0]); -} - -void ParallelTcpStore::DoWait(const pta::StoreMessage &req, pta::StoreMessage &res) -{ - auto ret = client_.SyncCall(req, res); - if (ret != 0) { - throw std::runtime_error{ "get number keys failed." }; + void ParallelTcpStore::DoWait(const pta::StoreMessage &req, pta::StoreMessage &res) + { + int ret = -1; + if (proxy_) { + ret = proxy_->SyncCall(request, response); + } else { + ret = client_.SyncCall(request, response); + } + if (ret != 0) { + throw std::runtime_error{ "get number keys failed." 
}; + } } -} -std::shared_ptr ParallelTcpStore::GetSharedServer(const std::string &initKey, uint16_t port, - c10::optional numWorkers) -{ - std::unique_lock lockGuard{ cacheServerMutex_ }; - auto pos = cachedServers_.find(port); - if (pos != cachedServers_.end()) { - auto server = pos->second.lock(); - if (server != nullptr) { - return server; + std::shared_ptr ParallelTcpStore::GetSharedServer(const std::string &initKey, + uint16_t port, c10::optional numWorkers) + { + std::unique_lock lockGuard{ cacheServerMutex_ }; + auto pos = cachedServers_.find(port); + if (pos != cachedServers_.end()) { + auto server = pos->second.lock(); + if (server != nullptr) { + return server; + } + + cachedServers_.erase(pos); } - cachedServers_.erase(pos); + auto server = std::make_shared(initKey, port, numWorkers); + cachedServers_.emplace(port, server); + return server; } - - auto server = std::make_shared(initKey, port, numWorkers); - cachedServers_.emplace(port, server); - return server; -} } // c10d \ No newline at end of file diff --git a/torch_npu/csrc/distributed/ParallelTcpStore.hpp b/torch_npu/csrc/distributed/ParallelTcpStore.hpp index 5a65d6d0ff..6301880eff 100644 --- a/torch_npu/csrc/distributed/ParallelTcpStore.hpp +++ b/torch_npu/csrc/distributed/ParallelTcpStore.hpp @@ -1,4 +1,4 @@ -/** +/* * * @copyright Copyright (c) 2024 Huawei Technologies Co., Ltd. All rights reserved. 
* * Licensed under the BSD 3-Clause License (the "License"); @@ -25,16 +25,15 @@ #include #include "c10d/TCPStore.hpp" -#include "Client.hpp" +#include "StoreClient.hpp" #include "ParallelTcpServer.hpp" namespace c10d { namespace pta { - class Proxy; -using CallBackFn = std::function; +using CallBackFn = std::function; class ParallelStoreServer { -public: +public: ParallelStoreServer(std::string initKey, uint16_t port, c10::optional numWorkers); ParallelStoreServer(std::string localSocketPath, CallBackFn callback); virtual ~ParallelStoreServer() noexcept; @@ -51,6 +50,7 @@ private: pta::StoreMessage ProcessGetNumKeyRequest(int fd, const pta::StoreMessage &request) noexcept; pta::StoreMessage ProcessWaitKeysRequest(int fd, const pta::StoreMessage &request) noexcept; void InitializeHandlers() noexcept; + void LocalInitializeHandlers() noexcept; bool CheckAllKeysExistInLock(const std::vector &keys) noexcept; private: @@ -72,7 +72,7 @@ private: class ParallelTcpStore : public Store { public: - explicit ParallelTcpStore(const std::string &host, const TCPStoreOptions &opts = {}); + explicit ParallelTcpStore(const std::string &host, const int16_t &localRank const TCPStoreOptions &opts = {}); ~ParallelTcpStore() noexcept override; public: @@ -95,8 +95,8 @@ private: static std::shared_ptr GetSharedServer(const std::string &initKey, uint16_t port, c10::optional numWorkers); -private: - pta::Client client_; +private: + pta::Client client_; std::unique_ptr proxy_; std::shared_ptr server_; std::mutex clientMutex_; diff --git a/torch_npu/csrc/distributed/Client.cpp b/torch_npu/csrc/distributed/StoreClient.cpp similarity index 90% rename from torch_npu/csrc/distributed/Client.cpp rename to torch_npu/csrc/distributed/StoreClient.cpp index f9a3709d27..934a563f6a 100644 --- a/torch_npu/csrc/distributed/Client.cpp +++ b/torch_npu/csrc/distributed/StoreClient.cpp @@ -1,4 +1,4 @@ -/** +/* * * @copyright Copyright (c) 2024 Huawei Technologies Co., Ltd. All rights reserved. 
* * Licensed under the BSD 3-Clause License (the "License"); @@ -25,15 +25,15 @@ #include #include "c10/util/Logging.h" -#include "Client.hpp" +#include "StoreClient.hpp" #include "StoreMessagePacker.hpp" namespace c10d { namespace pta { static constexpr uint32_t READ_BUF_SZ = 256; -Client::Client(const std::string host, uint16_t port) noexcept - : host_{ std::move(host) }, port_{ port }, socketFd_(-1) {} +Client::Client(const std::string host, uint16_t port) noexcept : host_{ std::move(host) }, port_{ port }, socketFd_(-1) +{} Client::Client() noexcept {} void Client::init(std::string localSocketPath) noexcept @@ -41,7 +41,7 @@ void Client::init(std::string localSocketPath) noexcept localSocketPath_ = localSocketPath; } -int Client::Connect() noexcept +int Client::Connect() noexcept { socketFd_ = socket(AF_INET, SOCK_STREAM, 0); if (socketFd_ < 0) { @@ -94,7 +94,7 @@ int Client::Close() noexcept return ret; } -int Client::LocalConnect() noexcept +int Client::LocalConnect() noexcept { socketFd_ = socket(AF_UNIX, SOCK_STREAM, 0); if (socketFd_ < 0) { @@ -103,18 +103,20 @@ int Client::LocalConnect() noexcept } struct sockaddr_un servAddr {}; servAddr.sun_family = AF_UNIX; - strncpy(servAddr.sun_path, localSocketPath_.c_str(), sizeof(servAddr.sun_path) - 1); + servAddr.sun_path[0] = '\0'; + strncpy(servAddr.sun_path + 1, localSocketPath_.c_str(), sizeof(servAddr.sun_path) - 1); int lastError = 0; auto endTime = std::chrono::steady_clock::now() + std::chrono::minutes(1); while (std::chrono::steady_clock::now() < endTime) { - auto ret = connect(socketFd_, reinterpret_cast(&servAddr), sizeof(servAddr)); + auto ret = connect(socketFd_, reinterpret_cast(&servAddr), sizeof(servAddr)); if (ret == 0) { return 0; } if (errno != lastError) { - LOG(ERROR) << "connect socket to local server(" << localSocketPath_ << ") failed " << errno << " : " << strerror(errno); + LOG(ERROR) << "connect socket to local server(" << localSocketPath_ << ") failed " << errno << " : " << + 
strerror(errno); lastError = errno; } @@ -129,10 +131,9 @@ int Client::LocalConnect() noexcept } return -1; - } -int Client::LocalClose() noexcept +int Client::LocalClose() noexcept { auto ret = close(socketFd_); if (ret == 0) { @@ -140,8 +141,7 @@ int Client::LocalClose() noexcept return 0; } - LOG(ERROR) << "close socket to server(" << localSocketPath_ << ") failed " << errno << " : " << - strerror(errno); + LOG(ERROR) << "close socket to server(" << localSocketPath_ << ") failed " << errno << " : " << strerror(errno); return ret; } diff --git a/torch_npu/csrc/distributed/Client.hpp b/torch_npu/csrc/distributed/StoreClient.hpp similarity index 91% rename from torch_npu/csrc/distributed/Client.hpp rename to torch_npu/csrc/distributed/StoreClient.hpp index 34b598e2fd..c7e7f263e3 100644 --- a/torch_npu/csrc/distributed/Client.hpp +++ b/torch_npu/csrc/distributed/StoreClient.hpp @@ -1,4 +1,4 @@ -/** +/* * * @copyright Copyright (c) 2024 Huawei Technologies Co., Ltd. All rights reserved. * * Licensed under the BSD 3-Clause License (the "License"); @@ -25,7 +25,7 @@ namespace c10d { namespace pta { class Client { public: - Client(const std::string host, uint16_t port) noexcept; // for tcp client + Client(const std::string host, uint16_t port) noexcept; // for tcp client Client() noexcept; void init(std::string localSocketPath) noexcept; int Connect() noexcept; @@ -38,9 +38,8 @@ public: private: std::string localSocketPath_{}; const std::string host_{}; - const uint16_t port_{0}; + const uint16_t port_{ 0 }; int socketFd_; }; } // pta } // c10d - diff --git a/torch_npu/distributed/rendezvous.py b/torch_npu/distributed/rendezvous.py index 20294541d8..b3293427e7 100644 --- a/torch_npu/distributed/rendezvous.py +++ b/torch_npu/distributed/rendezvous.py @@ -31,7 +31,7 @@ def _torchelastic_use_agent_store() -> bool: return os.environ.get("TORCH_NPU_ELASTIC_USE_AGENT_STORE", None) == str(True) -def _create_c10d_store(hostname, port, rank, world_size, timeout) -> Store: +def 
_create_c10d_store(hostname, port, rank, world_size, local_rank, timeout) -> Store: """ Smartly creates a c10d Store object on ``rank`` based on whether we need to re-use agent store. The TCPStore server is assumed to be hosted @@ -53,12 +53,12 @@ def _create_c10d_store(hostname, port, rank, world_size, timeout) -> Store: if _torchelastic_use_agent_store(): attempt = os.environ["TORCHELASTIC_RESTART_COUNT"] - tcp_store = ParallelStore(hostname, port, world_size, False, timeout) + tcp_store = ParallelStore(hostname, port, world_size, local_rank, False, timeout) return PrefixStore(f"/worker/attempt_{attempt}", tcp_store) else: start_daemon = rank == 0 return ParallelStore( - hostname, port, world_size, start_daemon, timeout, multi_tenant=True + hostname, port, world_size, local_rank, start_daemon, timeout, multi_tenant=True ) @@ -81,7 +81,8 @@ def _parallel_rendezvous_handler( rank = int(query["rank"]) world_size = int(query["world_size"]) - store = _create_c10d_store(result.hostname, result.port, rank, world_size, timeout) + local_rank = int(os.getenv('LOCAL_RANK', -1)) + store = _create_c10d_store(result.hostname, result.port, rank, world_size, local_rank, timeout) yield (store, rank, world_size) @@ -102,6 +103,7 @@ class ParallelTCPRendezvous(RendezvousHandler): master_port: int, rank: int, world_size: int, + local_rank: int, run_id: str, timeout: int, ): @@ -109,6 +111,7 @@ class ParallelTCPRendezvous(RendezvousHandler): self.master_port = master_port self.rank = rank self.world_size = world_size + self.local_rank = local_rank self.run_id = run_id self.timeout = timedelta(seconds=timeout) self._store: Optional[Store] = None @@ -124,6 +127,7 @@ class ParallelTCPRendezvous(RendezvousHandler): self.master_addr, self.master_port, self.world_size, + self.local_rank, is_master, self.timeout, multi_tenant=True, @@ -177,8 +181,9 @@ def _create_parallel_handler(params: RendezvousParameters) -> RendezvousHandler: timeout = _default_timeout_seconds 
os.environ.setdefault("TORCH_NPU_ELASTIC_USE_AGENT_STORE", str(True)) + local_rank = int(os.getenv('LOCAL_RANK', -1)) return ParallelTCPRendezvous( - master_addr, master_port, rank, world_size, run_id, timeout + master_addr, master_port, rank, world_size, local_rank, run_id, timeout ) -- Gitee From 47228e291acb88f211846c925b08a2ab82fd581c Mon Sep 17 00:00:00 2001 From: wuyangyu Date: Thu, 5 Sep 2024 10:36:59 +0800 Subject: [PATCH 14/14] debug --- .../csrc/distributed/ParallelTcpServer.hpp | 1 - .../csrc/distributed/ParallelTcpStore.cpp | 225 +++++++++--------- 2 files changed, 113 insertions(+), 113 deletions(-) diff --git a/torch_npu/csrc/distributed/ParallelTcpServer.hpp b/torch_npu/csrc/distributed/ParallelTcpServer.hpp index aa33b3dc06..c0969733d5 100644 --- a/torch_npu/csrc/distributed/ParallelTcpServer.hpp +++ b/torch_npu/csrc/distributed/ParallelTcpServer.hpp @@ -92,7 +92,6 @@ private: }; using ServerProcFn = std::function; -using CallBackFn = std::function; /* * * @brief epoll based TCP server with registered message processor. 
diff --git a/torch_npu/csrc/distributed/ParallelTcpStore.cpp b/torch_npu/csrc/distributed/ParallelTcpStore.cpp index fe889fae35..7f2d1af865 100644 --- a/torch_npu/csrc/distributed/ParallelTcpStore.cpp +++ b/torch_npu/csrc/distributed/ParallelTcpStore.cpp @@ -373,142 +373,143 @@ std::vector ParallelTcpStore::get(const std::string &key) ret = proxy_->SyncCall(waitReq, waitResp); if (ret != 0) { throw std::runtime_error{ std::string("proxy_ sync wait msg ").append(" failed.") }; - ret = proxy_->SyncCall(getReq, getResp); - if (ret != 0) { - throw std::runtime_error{ std::string("proxy_ sync get msg ").append(" failed.") }; - } - } else { - ret = client_.SyncCall(waitReq, waitResp); - if (ret != 0) { - throw std::runtime_error{ std::string("get key ").append(key).append(" failed.") }; - } - - ret = client_.SyncCall(getReq, getResp); - if (ret != 0) { - throw std::runtime_error{ std::string("get key ").append(key).append(" failed.") }; - } } - return getResp.values.empty() ? std::vector{} : std::move(getResp.values[0]); - } - - int64_t ParallelTcpStore::add(const std::string &key, int64_t value) - { - return IncreaseKey(key, value); - } - - bool ParallelTcpStore::deleteKey(const std::string &key) - { - pta::StoreMessage request{ pta::MessageType::DELETE_KEY, key }; - pta::StoreMessage response; - std::lock_guard lockGuard{ clientMutex_ }; - int ret = -1; - if (proxy_) { - ret = proxy_->SyncCall(request, response); - } else { - ret = client_.SyncCall(request, response); + ret = proxy_->SyncCall(getReq, getResp); + if (ret != 0) { + throw std::runtime_error{ std::string("proxy_ sync get msg ").append(" failed.") }; } + } else { + ret = client_.SyncCall(waitReq, waitResp); if (ret != 0) { - throw std::runtime_error{ std::string("delete key ").append(key).append(" failed.") }; + throw std::runtime_error{ std::string("get key ").append(key).append(" failed.") }; } - return !response.values.empty() && !response.values[0].empty() && response.values[0][0] > 0U; + ret = 
client_.SyncCall(getReq, getResp); + if (ret != 0) { + throw std::runtime_error{ std::string("get key ").append(key).append(" failed.") }; + } } + return getResp.values.empty() ? std::vector{} : std::move(getResp.values[0]); +} - bool ParallelTcpStore::check(const std::vector &keys) - { - throw std::runtime_error("unsupported check operation."); +int64_t ParallelTcpStore::add(const std::string &key, int64_t value) /* Forwards to IncreaseKey(key, value) and returns its result. */ +{ + return IncreaseKey(key, value); +} + +bool ParallelTcpStore::deleteKey(const std::string &key) /* Sends a DELETE_KEY request for key while holding clientMutex_, through proxy_ when it is set and through client_ otherwise; throws std::runtime_error when SyncCall returns non-zero. */ +{ + pta::StoreMessage request{ pta::MessageType::DELETE_KEY, key }; + pta::StoreMessage response; + std::lock_guard lockGuard{ clientMutex_ }; + int ret = -1; + if (proxy_) { + ret = proxy_->SyncCall(request, response); + } else { + ret = client_.SyncCall(request, response); + } + if (ret != 0) { + throw std::runtime_error{ std::string("delete key ").append(key).append(" failed.") }; } - int64_t ParallelTcpStore::getNumKeys() - { - pta::StoreMessage request{ pta::MessageType::GET_NUM_KEYS }; - pta::StoreMessage response; - std::lock_guard lockGuard{ clientMutex_ }; - int ret = -1; - if (proxy_) { - ret = proxy_->SyncCall(request, response); - } else { - ret = client_.SyncCall(request, response); - } - if (ret != 0) { - throw std::runtime_error{ "get number keys failed." 
}; - } + return !response.values.empty() && !response.values[0].empty() && response.values[0][0] > 0U; /* true only when the first response value is non-empty and its first element is greater than zero */ +} - return pta::StoreMessagePacker::UnpackPod(response.values[0]); - } +bool ParallelTcpStore::check(const std::vector &keys) /* Unsupported operation: always throws std::runtime_error. */ +{ + throw std::runtime_error("unsupported check operation."); +} - void ParallelTcpStore::wait(const std::vector &keys) - { - wait(keys, timeout_); +int64_t ParallelTcpStore::getNumKeys() /* Sends a GET_NUM_KEYS request while holding clientMutex_, through proxy_ when it is set and through client_ otherwise; throws std::runtime_error when SyncCall returns non-zero. */ +{ + pta::StoreMessage request{ pta::MessageType::GET_NUM_KEYS }; + pta::StoreMessage response; + std::lock_guard lockGuard{ clientMutex_ }; + int ret = -1; + if (proxy_) { + ret = proxy_->SyncCall(request, response); + } else { + ret = client_.SyncCall(request, response); } - - void ParallelTcpStore::wait(const std::vector &keys, const std::chrono::milliseconds &timeout) - { - pta::StoreMessage request{ pta::MessageType::WAIT, keys }; - pta::StoreMessage response; - client_.SetReceiveTimeout(timeout); - std::lock_guard lockGuard{ clientMutex_ }; - DoWait(request, response); + if (ret != 0) { + throw std::runtime_error{ "get number keys failed." 
}; } - const std::chrono::milliseconds &ParallelTcpStore::getTimeout() const noexcept - { - return timeout_; - } + return pta::StoreMessagePacker::UnpackPod(response.values[0]); /* NOTE(review): indexes response.values[0] without the emptiness check that deleteKey performs - confirm a successful reply always carries one value */ +} - void ParallelTcpStore::setTimeout(const std::chrono::milliseconds &timeout) - { - timeout_ = timeout; - } +void ParallelTcpStore::wait(const std::vector &keys) /* Waits for keys by calling the two-argument overload with the stored timeout_. */ +{ + wait(keys, timeout_); +} - int64_t ParallelTcpStore::IncreaseKey(const std::string &key, int64_t value) - { - pta::StoreMessage request{ pta::MessageType::ADD, key, pta::StoreMessagePacker::PackPod(value) }; - pta::StoreMessage response; - std::lock_guard lockGuard{ clientMutex_ }; - int ret = -1; - if (proxy_) { - ret = proxy_->SyncCall(request, response); - } else { - ret = client_.SyncCall(request, response); - } +void ParallelTcpStore::wait(const std::vector &keys, const std::chrono::milliseconds &timeout) /* Builds a WAIT request for keys and issues it through DoWait while holding clientMutex_; DoWait throws std::runtime_error on failure. */ +{ + pta::StoreMessage request{ pta::MessageType::WAIT, keys }; + pta::StoreMessage response; + client_.SetReceiveTimeout(timeout); /* NOTE(review): runs before clientMutex_ is taken and only configures client_, not proxy_ - confirm this is intended */ + std::lock_guard lockGuard{ clientMutex_ }; + DoWait(request, response); +} - if (ret != 0) { - throw std::runtime_error{ std::string("add key ").append(key).append(" failed.") }; - } +const std::chrono::milliseconds &ParallelTcpStore::getTimeout() const noexcept /* Returns a reference to the stored timeout_. */ +{ + return timeout_; +} + +void ParallelTcpStore::setTimeout(const std::chrono::milliseconds &timeout) /* Overwrites the stored timeout_. 
*/ +{ + timeout_ = timeout; +} - return pta::StoreMessagePacker::UnpackPod(response.values[0]); +int64_t ParallelTcpStore::IncreaseKey(const std::string &key, int64_t value) /* Sends an ADD request for key carrying value packed with PackPod, while holding clientMutex_; throws std::runtime_error when SyncCall returns non-zero. */ +{ + pta::StoreMessage request{ pta::MessageType::ADD, key, pta::StoreMessagePacker::PackPod(value) }; + pta::StoreMessage response; + std::lock_guard lockGuard{ clientMutex_ }; + int ret = -1; + if (proxy_) { + ret = proxy_->SyncCall(request, response); + } else { + ret = client_.SyncCall(request, response); } - void ParallelTcpStore::DoWait(const pta::StoreMessage &req, pta::StoreMessage &res) - { - int ret = -1; - if (proxy_) { - ret = 
proxy_->SyncCall(request, response); - } else { - ret = client_.SyncCall(request, response); - } - if (ret != 0) { - throw std::runtime_error{ "get number keys failed." }; - } + if (ret != 0) { + throw std::runtime_error{ std::string("add key ").append(key).append(" failed.") }; } - std::shared_ptr ParallelTcpStore::GetSharedServer(const std::string &initKey, - uint16_t port, c10::optional numWorkers) - { - std::unique_lock lockGuard{ cacheServerMutex_ }; - auto pos = cachedServers_.find(port); - if (pos != cachedServers_.end()) { - auto server = pos->second.lock(); - if (server != nullptr) { - return server; - } + return pta::StoreMessagePacker::UnpackPod(response.values[0]); /* NOTE(review): indexes response.values[0] without an emptiness check - confirm a successful ADD reply always carries one value */ +} - cachedServers_.erase(pos); +void ParallelTcpStore::DoWait(const pta::StoreMessage &req, pta::StoreMessage &res) /* Issues req through proxy_ when it is set and through client_ otherwise, filling res; takes no lock itself (wait acquires clientMutex_ before calling). */ +{ + int ret = -1; + if (proxy_) { + ret = proxy_->SyncCall(req, res); + } else { + ret = client_.SyncCall(req, res); + } + if (ret != 0) { + throw std::runtime_error{ "wait keys failed." }; /* message used to read "get number keys failed.", copied from getNumKeys */ + } +} + +std::shared_ptr ParallelTcpStore::GetSharedServer(const std::string &initKey, + uint16_t port, c10::optional numWorkers) /* Under cacheServerMutex_: returns the server cached for port when its entry can still be locked; otherwise erases the stale entry, creates a new server from (initKey, port, numWorkers) and caches it. */ +{ + std::unique_lock lockGuard{ cacheServerMutex_ }; + auto pos = cachedServers_.find(port); + if (pos != cachedServers_.end()) { + auto server = pos->second.lock(); + if (server != nullptr) { + return server; } - auto server = std::make_shared(initKey, port, numWorkers); - cachedServers_.emplace(port, server); - return server; + cachedServers_.erase(pos); } + + auto server = std::make_shared(initKey, port, numWorkers); + cachedServers_.emplace(port, server); + return server; +} } // c10d \ No newline at end of file -- Gitee