diff --git a/torch_npu/csrc/distributed/Init.cpp b/torch_npu/csrc/distributed/Init.cpp index f4500b828b2648e4381e9278d4cffae71e05e6d2..a75767f0c480264440260a046c5af56c98252ee9 100644 --- a/torch_npu/csrc/distributed/Init.cpp +++ b/torch_npu/csrc/distributed/Init.cpp @@ -476,6 +476,7 @@ Example:: .def(py::init([](const std::string &host, uint16_t port, int worldSize, + int16_t localRank, bool isServer, std::chrono::milliseconds timeout, bool waitWorkers, @@ -486,8 +487,8 @@ Example:: } ::c10d::TCPStoreOptions opts{ port, isServer, numWorkers, waitWorkers, timeout, multiTenant }; - return c10::make_intrusive <::c10d::ParallelTcpStore>(host, opts); - }), py::arg("host") = "127.0.0.1", py::arg("port") = 29500, py::arg("world_size") = -1, + return c10::make_intrusive <::c10d::ParallelTcpStore>(host, localRank, opts); + }), py::arg("host") = "127.0.0.1", py::arg("port") = 29500, py::arg("world_size") = -1, py::arg("localRank") = -2, py::arg("is_server") = false, py::arg("timeout") = std::chrono::milliseconds(300000), py::arg("wait_workers") = true, py::arg("multi_tenant") = false); diff --git a/torch_npu/csrc/distributed/ParallelStoreProxy.cpp b/torch_npu/csrc/distributed/ParallelStoreProxy.cpp new file mode 100644 index 0000000000000000000000000000000000000000..647120de44116594cf99b7e49a50a7edb3625b2e --- /dev/null +++ b/torch_npu/csrc/distributed/ParallelStoreProxy.cpp @@ -0,0 +1,42 @@ +#include "ParallelStoreProxy.hpp" +#include "ParallelTcpStore.hpp" +#include "StoreClient.hpp" +#include "c10/util/Exception.h" +#include +#include "StoreMessagePacker.hpp" + +namespace c10d { +namespace pta { +Proxy::Proxy(const std::string localSocketPath, const std::string host, uint16_t port) + : localServer_{ std::make_unique(localSocketPath, + [this](const StoreMessage &msg) { return this->HandleLocalServerMessage(msg); }) }, + tcpClient_{ std::make_unique(host, port) } +{} + +void Proxy::Start() noexcept +{ + if (tcpClient_->Connect() != 0) { + throw std::runtime_error("Failed to connect to TCP server"); + } +} + +void Proxy::Stop() noexcept +{ + tcpClient_->Close(); +} + +int Proxy::SyncCall(const pta::StoreMessage &request, pta::StoreMessage &response) noexcept +{ + return tcpClient_->SyncCall(request, response); +} + +StoreMessage Proxy::HandleLocalServerMessage(const StoreMessage &message) noexcept +{ + StoreMessage response; + if (tcpClient_->SyncCall(message, response) != 0) { + throw std::runtime_error("Failed to sync call with TCP server"); + } + return response; +} +} // namespace pta +} // namespace c10d \ No newline at end of file diff --git a/torch_npu/csrc/distributed/ParallelStoreProxy.hpp b/torch_npu/csrc/distributed/ParallelStoreProxy.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c1fa0a5b853e7afb2d41b7530e787953328c1558 --- /dev/null +++ b/torch_npu/csrc/distributed/ParallelStoreProxy.hpp @@ -0,0 +1,25 @@ +#pragma once + +#include +#include +#include "StoreClient.hpp" +#include "StoreMessagePacker.hpp" + +namespace c10d { +namespace pta { +class ParallelStoreServer; + +class Proxy { +public: + Proxy(std::string localSocketPath, const std::string host, uint16_t port); + void Start() noexcept; + void Stop() noexcept; + int SyncCall(const pta::StoreMessage &request, pta::StoreMessage &response) noexcept; + StoreMessage HandleLocalServerMessage(const StoreMessage &message) noexcept; + +private: + std::unique_ptr localServer_; + std::unique_ptr tcpClient_; +}; +} // namespace pta +} // namespace c10d \ No newline at end of file diff --git a/torch_npu/csrc/distributed/ParallelTcpServer.cpp b/torch_npu/csrc/distributed/ParallelTcpServer.cpp index 1d91ab6b02e6d03c63d2dc523c3245aa6805b14f..8203c10a9b78ec18e2924d7bb479c50ea0eef6b4 100644 --- a/torch_npu/csrc/distributed/ParallelTcpServer.cpp +++ b/torch_npu/csrc/distributed/ParallelTcpServer.cpp @@ -1,4 +1,4 @@ -/** +/* * * @copyright Copyright (c) 2024 Huawei Technologies Co., Ltd. All rights reserved. * * Licensed under the BSD 3-Clause License (the "License"); @@ -15,6 +15,7 @@ */ #include #include +#include #include #include #include @@ -106,6 +107,13 @@ ParallelTcpServer::ParallelTcpServer(uint32_t threadNum, uint16_t port, ServerPr : threadNum_{ std::max(4U, threadNum) }, port_{ port }, process_{ std::move(process) } {} +ParallelTcpServer::ParallelTcpServer(uint32_t threadNum, const std::string localSocketPath, + ServerProcFn process) noexcept + : threadNum_{ std::max(4U, threadNum) }, localSocketPath_{ localSocketPath }, process_{ std::move(process) } +{ + isLocalServer_ = true; +} + int ParallelTcpServer::Start() noexcept { buffer_ = new (std::nothrow) uint8_t[4096]; @@ -114,7 +122,11 @@ int ParallelTcpServer::Start() noexcept return -1; } - listenSocket_ = CreateSocket(port_); + if (isLocalServer_) { + listenSocket_ = CreateLocalSocket(localSocketPath_); + } else { + listenSocket_ = CreateSocket(port_); + } if (listenSocket_ < 0) { delete[] buffer_; buffer_ = nullptr; @@ -196,8 +208,8 @@ void ParallelTcpServer::WakeupWaitingClients(const std::string &key) noexcept keyWaitingSockets_.erase(key); lockGuard.unlock(); - std::vector body{static_cast(MessageWaitKeyRes::KEYS_STOP_WAITING)}; - StoreMessage response{MessageType::WAIT, body}; + std::vector body{ static_cast(MessageWaitKeyRes::KEYS_STOP_WAITING) }; + StoreMessage response{ MessageType::WAIT, body }; auto buf = StoreMessagePacker::Pack(response); for (auto socket : stopWaitingSockets) { write(socket, buf.data(), buf.size()); @@ -239,6 +251,43 @@ int ParallelTcpServer::CreateSocket(uint16_t port) noexcept return sockFd; } +int ParallelTcpServer::CreateLocalSocket(const std::string &localSocketPath) noexcept +{ + struct sockaddr_un servAddr {}; + servAddr.sun_family = AF_UNIX; + servAddr.sun_path[0] = '\0'; + strncpy(servAddr.sun_path + 1, localSocketPath.c_str(), sizeof(servAddr.sun_path) - 1); + + auto sockFd = ::socket(AF_UNIX, SOCK_STREAM, 0); + if (sockFd < 0) { + LOG(ERROR) << "create local socket fd failed " << errno << " : " << strerror(errno); + return -1; + } + + unlink(localSocketPath.c_str()); // Remove any existing socket file + + auto ret = ::bind(sockFd, reinterpret_cast(&servAddr), sizeof(servAddr)); + if (ret != 0) { + LOG(ERROR) << "bind local socket fd failed " << errno << " : " << strerror(errno); + close(sockFd); + return -1; + } + + ret = listen(sockFd, MAX_EVENT_COUNT); + if (ret != 0) { + LOG(ERROR) << "listen local socket fd failed " << errno << " : " << strerror(errno); + close(sockFd); + return -1; + } + + if (SetNonBlocking(sockFd) != 0) { + close(sockFd); + return -1; + } + + return sockFd; +} + int ParallelTcpServer::CreateEpoll(int targetFd) noexcept { auto fd = epoll_create(1); diff --git a/torch_npu/csrc/distributed/ParallelTcpServer.hpp b/torch_npu/csrc/distributed/ParallelTcpServer.hpp index ad95e940822714269e7d6a734f306e553ed62c26..c0969733d54c4bd39191480c61a2d855d60a15ce 100644 --- a/torch_npu/csrc/distributed/ParallelTcpServer.hpp +++ b/torch_npu/csrc/distributed/ParallelTcpServer.hpp @@ -1,4 +1,4 @@ -/** +/* * * @copyright Copyright (c) 2024 Huawei Technologies Co., Ltd. All rights reserved. * * Licensed under the BSD 3-Clause License (the "License"); @@ -27,7 +27,7 @@ namespace c10d { namespace pta { -/** +/* * * @brief wrapper for pthread_spinlock_t */ class SpinLock { @@ -61,7 +61,7 @@ private: pthread_spinlock_t spinlock_{}; }; -/** +/* * * @brief store client IO context for server. */ class ClientIoContext { @@ -93,15 +93,16 @@ private: using ServerProcFn = std::function; -/** +/* * * @brief epoll based TCP server with registered message processor. */ class ParallelTcpServer { public: explicit ParallelTcpServer(uint32_t threadNum, uint16_t port, ServerProcFn process) noexcept; + explicit ParallelTcpServer(uint32_t threadNum, const std::string localSocketPath, ServerProcFn process) noexcept; int Start() noexcept; - + int LocalStart() noexcept; void Stop() noexcept; inline void SetKeysWaitingSocket(const std::vector &keys, int socket, int64_t waitCount) noexcept @@ -117,6 +118,7 @@ public: private: static int CreateSocket(uint16_t port) noexcept; + static int CreateLocalSocket(const std::string &localSocketPath) noexcept; static int CreateEpoll(int targetFd = -1) noexcept; @@ -131,11 +133,13 @@ private: static int SetNonBlocking(int fd) noexcept; private: - const uint32_t threadNum_; - const std::uint16_t port_; - const ServerProcFn process_; + const uint32_t threadNum_{ 0 }; + const std::uint16_t port_{ 0 }; + const std::string localSocketPath_{}; + const ServerProcFn process_{ nullptr }; int listenSocket_{ -1 }; int epCtlFd_{ -1 }; + bool isLocalServer_{ false }; std::thread ctlThread_; std::vector epClientFds_; std::vector clientThreads_; diff --git a/torch_npu/csrc/distributed/ParallelTcpStore.cpp b/torch_npu/csrc/distributed/ParallelTcpStore.cpp index d2acf0df9b8c8f953f6409b868c41a0653a2d719..7f2d1af86530e71452d9c2c2c3e65dc141dab026 100644 --- a/torch_npu/csrc/distributed/ParallelTcpStore.cpp +++ b/torch_npu/csrc/distributed/ParallelTcpStore.cpp @@ -1,4 +1,4 @@ -/** +/* * * @copyright Copyright (c) 2024 Huawei Technologies Co., Ltd. All rights reserved. * * Licensed under the BSD 3-Clause License (the "License"); @@ -15,6 +15,8 @@ */ #include "ParallelTcpServer.hpp" #include "ParallelTcpStore.hpp" +#include "ParallelStoreProxy.hpp" +#include "StoreClient.hpp" namespace c10d { namespace pta { @@ -42,6 +44,20 @@ ParallelStoreServer::ParallelStoreServer(std::string initKey, uint16_t port, c10 } } +ParallelStoreServer::ParallelStoreServer(std::string localSocketPath, CallBackFn callback) + : localSocketPath_(std::move(localSocketPath)), callback_(std::move(callback)) +{ + auto threadNum = 1U; + LocalInitializeHandlers(); + server_ = std::make_unique(threadNum, localSocketPath_, + [this](int fd, const pta::StoreMessage &request) { return ProcessRequest(fd, request); }); + if (server_->Start() != 0) { + throw std::runtime_error{ + std::string("start local server on socket ").append(localSocketPath_).append(" failed.") + }; + } +} + ParallelStoreServer::~ParallelStoreServer() noexcept { server_->Stop(); @@ -243,6 +259,26 @@ void ParallelStoreServer::InitializeHandlers() noexcept [this](int fd, const pta::StoreMessage &req) { return ProcessDeleteRequest(fd, req); }); } +void ParallelStoreServer::LocalInitializeHandlers() noexcept +{ + requestHandlers_.emplace(pta::MessageType::SET, + [this](int fd, const pta::StoreMessage &req) { return callback_(req); }); + requestHandlers_.emplace(pta::MessageType::COMPARE_SET, + [this](int fd, const pta::StoreMessage &req) { return callback_(req); }); + requestHandlers_.emplace(pta::MessageType::GET, + [this](int fd, const pta::StoreMessage &req) { return callback_(req); }); + requestHandlers_.emplace(pta::MessageType::ADD, + [this](int fd, const pta::StoreMessage &req) { return callback_(req); }); + requestHandlers_.emplace(pta::MessageType::CHECK, + [this](int fd, const pta::StoreMessage &req) { return callback_(req); }); + requestHandlers_.emplace(pta::MessageType::WAIT, + [this](int fd, const pta::StoreMessage &req) { return callback_(req); }); + requestHandlers_.emplace(pta::MessageType::GET_NUM_KEYS, + [this](int fd, const pta::StoreMessage &req) { return callback_(req); }); + requestHandlers_.emplace(pta::MessageType::DELETE_KEY, + [this](int fd, const pta::StoreMessage &req) { return callback_(req); }); +} + bool ParallelStoreServer::CheckAllKeysExistInLock(const std::vector &keys) noexcept { return std::all_of(keys.begin(), keys.end(), [this](const std::string &key) { return keyStore_.count(key) > 0; }); @@ -252,8 +288,8 @@ bool ParallelStoreServer::CheckAllKeysExistInLock(const std::vector std::mutex ParallelTcpStore::cacheServerMutex_; std::unordered_map> ParallelTcpStore::cachedServers_; -ParallelTcpStore::ParallelTcpStore(const std::string &host, const c10d::TCPStoreOptions &opts) - : Store(opts.timeout), client_{ host, opts.port } +ParallelTcpStore::ParallelTcpStore(const std::string &host, const int16_t &localRank, const c10d::TCPStoreOptions &opts) + : Store(opts.timeout) { if (opts.isServer) { if (opts.multiTenant) { @@ -263,14 +299,15 @@ ParallelTcpStore::ParallelTcpStore(const std::string &host, const c10d::TCPStore } } - if (client_.Connect() != 0) { - throw std::runtime_error{ std::string("connect tcp client to server(") - .append(host) - .append(":") - .append(std::to_string(opts.port)) - .append(" failed.") }; + if (localRank == -1) { + proxy_ = std::make_unique("/mnt/proxy/torch_dist_store", host, opts.port); + proxy_->Start(); + } else { + client_.init("/mnt/proxy/torch_dist_store"); + if (client_.LocalConnect() != 0) { + throw std::runtime_error{ std::string("connect local client to server failed.") }; + } } - if (opts.waitWorkers) { IncreaseKey(initKey_, 1); if (opts.isServer) { @@ -281,7 +318,11 @@ ParallelTcpStore::ParallelTcpStore(const std::string &host, const c10d::TCPStore ParallelTcpStore::~ParallelTcpStore() noexcept { - client_.Close(); + if (proxy_) { + proxy_->Stop(); + } else { + client_.LocalClose(); + } } void ParallelTcpStore::set(const std::string &key, const std::vector &value) @@ -289,7 +330,12 @@ void ParallelTcpStore::set(const std::string &key, const std::vector &v pta::StoreMessage request{ pta::MessageType::SET, key, value }; pta::StoreMessage response; std::lock_guard lockGuard{ clientMutex_ }; - auto ret = client_.SyncCall(request, response); + int ret = -1; + if (proxy_) { + ret = proxy_->SyncCall(request, response); + } else { + ret = client_.SyncCall(request, response); + } if (ret != 0) { throw std::runtime_error{ std::string("set key ").append(key).append(" failed.") }; } @@ -301,7 +347,12 @@ std::vector ParallelTcpStore::compareSet(const std::string &key, const pta::StoreMessage request{ pta::MessageType::COMPARE_SET, key, currentValue, newValue }; pta::StoreMessage response; std::lock_guard lockGuard{ clientMutex_ }; - auto ret = client_.SyncCall(request, response); + int ret = -1; + if (proxy_) { + ret = proxy_->SyncCall(request, response); + } else { + ret = client_.SyncCall(request, response); + } if (ret != 0) { throw std::runtime_error{ std::string("compare and set key ").append(key).append(" failed.") }; } @@ -317,16 +368,27 @@ std::vector ParallelTcpStore::get(const std::string &key) pta::StoreMessage getResp; std::lock_guard lockGuard{ clientMutex_ }; - auto ret = client_.SyncCall(waitReq, waitResp); - if (ret != 0) { - throw std::runtime_error{ std::string("get key ").append(key).append(" failed.") }; - } + int ret = -1; + if (proxy_) { + ret = proxy_->SyncCall(waitReq, waitResp); + if (ret != 0) { + throw std::runtime_error{ std::string("proxy_ sync wait msg ").append(" failed.") }; + } + ret = proxy_->SyncCall(getReq, getResp); + if (ret != 0) { + throw std::runtime_error{ std::string("proxy_ sync get msg ").append(" failed.") }; + } + } else { + ret = client_.SyncCall(waitReq, waitResp); + if (ret != 0) { + throw std::runtime_error{ std::string("get key ").append(key).append(" failed.") }; + } - ret = client_.SyncCall(getReq, getResp); - if (ret != 0) { - throw std::runtime_error{ std::string("get key ").append(key).append(" failed.") }; + ret = client_.SyncCall(getReq, getResp); + if (ret != 0) { + throw std::runtime_error{ std::string("get key ").append(key).append(" failed.") }; + } } - return getResp.values.empty() ? std::vector{} : std::move(getResp.values[0]); } @@ -340,7 +402,12 @@ bool ParallelTcpStore::deleteKey(const std::string &key) pta::StoreMessage request{ pta::MessageType::DELETE_KEY, key }; pta::StoreMessage response; std::lock_guard lockGuard{ clientMutex_ }; - auto ret = client_.SyncCall(request, response); + int ret = -1; + if (proxy_) { + ret = proxy_->SyncCall(request, response); + } else { + ret = client_.SyncCall(request, response); + } if (ret != 0) { throw std::runtime_error{ std::string("delete key ").append(key).append(" failed.") }; } @@ -358,7 +425,12 @@ int64_t ParallelTcpStore::getNumKeys() pta::StoreMessage request{ pta::MessageType::GET_NUM_KEYS }; pta::StoreMessage response; std::lock_guard lockGuard{ clientMutex_ }; - auto ret = client_.SyncCall(request, response); + int ret = -1; + if (proxy_) { + ret = proxy_->SyncCall(request, response); + } else { + ret = client_.SyncCall(request, response); + } if (ret != 0) { throw std::runtime_error{ "get number keys failed." }; } @@ -395,7 +467,13 @@ int64_t ParallelTcpStore::IncreaseKey(const std::string &key, int64_t value) pta::StoreMessage request{ pta::MessageType::ADD, key, pta::StoreMessagePacker::PackPod(value) }; pta::StoreMessage response; std::lock_guard lockGuard{ clientMutex_ }; - auto ret = client_.SyncCall(request, response); + int ret = -1; + if (proxy_) { + ret = proxy_->SyncCall(request, response); + } else { + ret = client_.SyncCall(request, response); + } + if (ret != 0) { throw std::runtime_error{ std::string("add key ").append(key).append(" failed.") }; } @@ -405,14 +483,19 @@ int64_t ParallelTcpStore::IncreaseKey(const std::string &key, int64_t value) void ParallelTcpStore::DoWait(const pta::StoreMessage &req, pta::StoreMessage &res) { - auto ret = client_.SyncCall(req, res); + int ret = -1; + if (proxy_) { + ret = proxy_->SyncCall(req, res); + } else { + ret = client_.SyncCall(req, res); + } if (ret != 0) { throw std::runtime_error{ "get number keys failed." }; } } -std::shared_ptr ParallelTcpStore::GetSharedServer(const std::string &initKey, uint16_t port, - c10::optional numWorkers) +std::shared_ptr ParallelTcpStore::GetSharedServer(const std::string &initKey, + uint16_t port, c10::optional numWorkers) { std::unique_lock lockGuard{ cacheServerMutex_ }; auto pos = cachedServers_.find(port); diff --git a/torch_npu/csrc/distributed/ParallelTcpStore.hpp b/torch_npu/csrc/distributed/ParallelTcpStore.hpp index 0841136eb557a362475547c42de2949f83812da5..6301880effffe7e8d697ded67d2ef2cec29e0d3a 100644 --- a/torch_npu/csrc/distributed/ParallelTcpStore.hpp +++ b/torch_npu/csrc/distributed/ParallelTcpStore.hpp @@ -1,4 +1,4 @@ -/** +/* * * @copyright Copyright (c) 2024 Huawei Technologies Co., Ltd. All rights reserved. * * Licensed under the BSD 3-Clause License (the "License"); @@ -25,13 +25,17 @@ #include #include "c10d/TCPStore.hpp" -#include "TcpClient.hpp" +#include "StoreClient.hpp" #include "ParallelTcpServer.hpp" namespace c10d { namespace pta { +class Proxy; + +using CallBackFn = std::function; class ParallelStoreServer { public: ParallelStoreServer(std::string initKey, uint16_t port, c10::optional numWorkers); + ParallelStoreServer(std::string localSocketPath, CallBackFn callback); virtual ~ParallelStoreServer() noexcept; void WaitWorkers(const std::chrono::milliseconds &timeout) noexcept; @@ -46,9 +50,12 @@ private: pta::StoreMessage ProcessGetNumKeyRequest(int fd, const pta::StoreMessage &request) noexcept; pta::StoreMessage ProcessWaitKeysRequest(int fd, const pta::StoreMessage &request) noexcept; void InitializeHandlers() noexcept; + void LocalInitializeHandlers() noexcept; bool CheckAllKeysExistInLock(const std::vector &keys) noexcept; private: + CallBackFn callback_; + std::string localSocketPath_; using RequestHandler = std::function; std::unique_ptr server_; std::unordered_map requestHandlers_; @@ -65,7 +72,7 @@ private: class ParallelTcpStore : public Store { public: - explicit ParallelTcpStore(const std::string &host, const TCPStoreOptions &opts = {}); + explicit ParallelTcpStore(const std::string &host, const int16_t &localRank const TCPStoreOptions &opts = {}); ~ParallelTcpStore() noexcept override; public: @@ -89,7 +96,8 @@ private: c10::optional numWorkers); private: - pta::TcpClient client_; + pta::Client client_; + std::unique_ptr proxy_; std::shared_ptr server_; std::mutex clientMutex_; std::condition_variable initWaitCond_; diff --git a/torch_npu/csrc/distributed/TcpClient.cpp b/torch_npu/csrc/distributed/StoreClient.cpp similarity index 67% rename from torch_npu/csrc/distributed/TcpClient.cpp rename to torch_npu/csrc/distributed/StoreClient.cpp index 44f3a95041fd654d58c6dd654dae39e1039ffa11..934a563f6a7fd3e1bfb56eee6792d671945f82b4 100644 --- a/torch_npu/csrc/distributed/TcpClient.cpp +++ b/torch_npu/csrc/distributed/StoreClient.cpp @@ -1,4 +1,4 @@ -/** +/* * * @copyright Copyright (c) 2024 Huawei Technologies Co., Ltd. All rights reserved. * * Licensed under the BSD 3-Clause License (the "License"); @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include #include #include #include @@ -24,17 +25,23 @@ #include #include "c10/util/Logging.h" -#include "TcpClient.hpp" +#include "StoreClient.hpp" +#include "StoreMessagePacker.hpp" namespace c10d { namespace pta { static constexpr uint32_t READ_BUF_SZ = 256; -TcpClient::TcpClient(std::string host, uint16_t port) noexcept - : host_{ std::move(host) }, port_{ port }, socketFd_{ -1 } +Client::Client(const std::string host, uint16_t port) noexcept : host_{ std::move(host) }, port_{ port }, socketFd_(-1) {} +Client::Client() noexcept {} -int TcpClient::Connect() noexcept +void Client::init(std::string localSocketPath) noexcept +{ + localSocketPath_ = localSocketPath; +} + +int Client::Connect() noexcept { socketFd_ = socket(AF_INET, SOCK_STREAM, 0); if (socketFd_ < 0) { @@ -74,7 +81,7 @@ int TcpClient::Connect() noexcept return -1; } -int TcpClient::Close() noexcept +int Client::Close() noexcept { auto ret = close(socketFd_); if (ret == 0) { @@ -87,7 +94,58 @@ int TcpClient::Close() noexcept return ret; } -int TcpClient::SyncCall(const StoreMessage &request, StoreMessage &response) noexcept +int Client::LocalConnect() noexcept +{ + socketFd_ = socket(AF_UNIX, SOCK_STREAM, 0); + if (socketFd_ < 0) { + LOG(ERROR) << "Create local socket failed: " << strerror(errno); + return -1; + } + struct sockaddr_un servAddr {}; + servAddr.sun_family = AF_UNIX; + servAddr.sun_path[0] = '\0'; + strncpy(servAddr.sun_path + 1, localSocketPath_.c_str(), sizeof(servAddr.sun_path) - 1); + + int lastError = 0; + auto endTime = std::chrono::steady_clock::now() + std::chrono::minutes(1); + while (std::chrono::steady_clock::now() < endTime) { + auto ret = connect(socketFd_, reinterpret_cast(&servAddr), sizeof(servAddr)); + if (ret == 0) { + return 0; + } + + if (errno != lastError) { + LOG(ERROR) << "connect socket to local server(" << localSocketPath_ << ") failed " << errno << " : " << + strerror(errno); + lastError = errno; + } + + if (errno == ETIMEDOUT) { + continue; + } + + if (errno == ECONNREFUSED) { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + continue; + } + } + + return -1; +} + +int Client::LocalClose() noexcept +{ + auto ret = close(socketFd_); + if (ret == 0) { + socketFd_ = -1; + return 0; + } + + LOG(ERROR) << "close socket to server(" << localSocketPath_ << ") failed " << errno << " : " << strerror(errno); + return ret; +} + +int Client::SyncCall(const StoreMessage &request, StoreMessage &response) noexcept { auto packedRequest = StoreMessagePacker::Pack(request); auto ret = write(socketFd_, packedRequest.data(), packedRequest.size()); @@ -134,7 +192,7 @@ int TcpClient::SyncCall(const StoreMessage &request, StoreMessage &response) noe return result; } -int TcpClient::SetReceiveTimeout(const std::chrono::milliseconds &value) const noexcept +int Client::SetReceiveTimeout(const std::chrono::milliseconds &value) const noexcept { if (value == std::chrono::milliseconds::zero()) { return 0; diff --git a/torch_npu/csrc/distributed/TcpClient.hpp b/torch_npu/csrc/distributed/StoreClient.hpp similarity index 75% rename from torch_npu/csrc/distributed/TcpClient.hpp rename to torch_npu/csrc/distributed/StoreClient.hpp index 822ff5e24eade982c6bba72a932de775ca8cb3c6..c7e7f263e397d0bda9a36af2115c2719fcc9f2cf 100644 --- a/torch_npu/csrc/distributed/TcpClient.hpp +++ b/torch_npu/csrc/distributed/StoreClient.hpp @@ -1,4 +1,4 @@ -/** +/* * * @copyright Copyright (c) 2024 Huawei Technologies Co., Ltd. All rights reserved. * * Licensed under the BSD 3-Clause License (the "License"); @@ -23,19 +23,23 @@ namespace c10d { namespace pta { -class TcpClient { +class Client { public: - TcpClient(std::string host, uint16_t port) noexcept; + Client(const std::string host, uint16_t port) noexcept; // for tcp client + Client() noexcept; + void init(std::string localSocketPath) noexcept; int Connect() noexcept; int Close() noexcept; + int LocalConnect() noexcept; + int LocalClose() noexcept; int SyncCall(const StoreMessage &request, StoreMessage &response) noexcept; int SetReceiveTimeout(const std::chrono::milliseconds &value) const noexcept; private: - const std::string host_; - const uint16_t port_; + std::string localSocketPath_{}; + const std::string host_{}; + const uint16_t port_{ 0 }; int socketFd_; }; } // pta } // c10d - diff --git a/torch_npu/distributed/rendezvous.py b/torch_npu/distributed/rendezvous.py index 20294541d8bde18e25eb793d24c92b18bf8a59c2..b3293427e7d4c17b5d15d5f79d66aedb5612d626 100644 --- a/torch_npu/distributed/rendezvous.py +++ b/torch_npu/distributed/rendezvous.py @@ -31,7 +31,7 @@ def _torchelastic_use_agent_store() -> bool: return os.environ.get("TORCH_NPU_ELASTIC_USE_AGENT_STORE", None) == str(True) -def _create_c10d_store(hostname, port, rank, world_size, timeout) -> Store: +def _create_c10d_store(hostname, port, rank, world_size, local_rank, timeout) -> Store: """ Smartly creates a c10d Store object on ``rank`` based on whether we need to re-use agent store. The TCPStore server is assumed to be hosted @@ -53,12 +53,12 @@ def _create_c10d_store(hostname, port, rank, world_size, timeout) -> Store: if _torchelastic_use_agent_store(): attempt = os.environ["TORCHELASTIC_RESTART_COUNT"] - tcp_store = ParallelStore(hostname, port, world_size, False, timeout) + tcp_store = ParallelStore(hostname, port, world_size, local_rank, False, timeout) return PrefixStore(f"/worker/attempt_{attempt}", tcp_store) else: start_daemon = rank == 0 return ParallelStore( - hostname, port, world_size, start_daemon, timeout, multi_tenant=True + hostname, port, world_size, local_rank, start_daemon, timeout, multi_tenant=True ) @@ -81,7 +81,8 @@ def _parallel_rendezvous_handler( rank = int(query["rank"]) world_size = int(query["world_size"]) - store = _create_c10d_store(result.hostname, result.port, rank, world_size, timeout) + local_rank = int(os.getenv('LOCAL_RANK', -1)) + store = _create_c10d_store(result.hostname, result.port, rank, world_size, local_rank, timeout) yield (store, rank, world_size) @@ -102,6 +103,7 @@ class ParallelTCPRendezvous(RendezvousHandler): master_port: int, rank: int, world_size: int, + local_rank: int, run_id: str, timeout: int, ): @@ -109,6 +111,7 @@ class ParallelTCPRendezvous(RendezvousHandler): self.master_port = master_port self.rank = rank self.world_size = world_size + self.local_rank = local_rank self.run_id = run_id self.timeout = timedelta(seconds=timeout) self._store: Optional[Store] = None @@ -124,6 +127,7 @@ class ParallelTCPRendezvous(RendezvousHandler): self.master_addr, self.master_port, self.world_size, + self.local_rank, is_master, self.timeout, multi_tenant=True, @@ -177,8 +181,9 @@ def _create_parallel_handler(params: RendezvousParameters) -> RendezvousHandler: timeout = _default_timeout_seconds os.environ.setdefault("TORCH_NPU_ELASTIC_USE_AGENT_STORE", str(True)) + local_rank = int(os.getenv('LOCAL_RANK', -1)) return ParallelTCPRendezvous( - master_addr, master_port, rank, world_size, run_id, timeout + master_addr, master_port, rank, world_size, local_rank, run_id, timeout )