diff --git a/third_party/hccl/inc/hccl/hccl.h b/third_party/hccl/inc/hccl/hccl.h index 023914a348285ad17c459b077cdd03c4593637ea..216ef7a83847e424ee1b0679b351d188452a2981 100644 --- a/third_party/hccl/inc/hccl/hccl.h +++ b/third_party/hccl/inc/hccl/hccl.h @@ -212,6 +212,8 @@ inline void HcclCommConfigInit(HcclCommConfig *config) config->hcclRdmaTrafficClass = HCCL_COMM_TRAFFIC_CLASS_CONFIG_NOT_SET; config->hcclRdmaServiceLevel = HCCL_COMM_SERVICE_LEVEL_CONFIG_NOT_SET; config->hcclOpExpansionMode = HCCL_COMM_DEFAULT_OP_EXPANSION_MODE; + config->hcclWorldRankID = 0; + config->hcclJobID = 0; } /** diff --git a/third_party/hccl/inc/hccl/hccl_types.h b/third_party/hccl/inc/hccl/hccl_types.h index 40631676c1bdc9bb44256b083e647e99e8f6fc8f..9a02c61c0414a96af23bf2468ab96482512240fa 100644 --- a/third_party/hccl/inc/hccl/hccl_types.h +++ b/third_party/hccl/inc/hccl/hccl_types.h @@ -15,7 +15,7 @@ extern "C" { const uint32_t HCCL_COMM_CONFIG_INFO_BYTES = 24; const uint32_t HCCL_COMM_CONFIG_MAGIC_WORD = 0xf0f0f0f0; -const uint32_t HCCL_COMM_CONFIG_VERSION = 5; +const uint32_t HCCL_COMM_CONFIG_VERSION = 6; const uint32_t HCCL_COMM_DEFAULT_BUFFSIZE = 200; // 200MB buffer size const uint32_t HCCL_COMM_DEFAULT_DETERMINISTIC = 0; // Disable deterministic calculations const uint32_t COMM_NAME_MAX_LENGTH = 128; @@ -132,6 +132,8 @@ typedef struct HcclCommConfigDef { uint32_t hcclOpExpansionMode; uint32_t hcclRdmaTrafficClass; uint32_t hcclRdmaServiceLevel; + uint32_t hcclWorldRankID; + uint64_t hcclJobID; } HcclCommConfig; typedef enum { diff --git a/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp b/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp index 73ee79512160d1602bbb185bd88673894c4dcd97..7188bcf625bbbe11bce87510a88cd1986e756478 100644 --- a/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp +++ b/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp @@ -19,6 +19,10 @@ #include #include #include +#include +#include + +#include #include "op_plugin/OpInterface.h" #include "third_party/acl/inc/acl/acl.h" @@ -63,6 +67,7 @@ constexpr const char* P2P_DEVICE_KEY = "_p2p"; using hcclUs = std::chrono::steady_clock::time_point; constexpr int32_t MAX_GROUP_NAME_LEN = 128; +constexpr int32_t NSLB_JOBID_OFFSET = 32; // HCCL ReduceOp mapping std::map hcclOp = { @@ -949,6 +954,24 @@ ProcessGroupHCCL::ProcessGroupHCCL( PrefixStore *prefixStore = dynamic_cast(store_.get()); globalStore_ = prefixStore ? prefixStore->getUnderlyingNonPrefixStore() : store_; + c10::intrusive_ptr getTcpStore = store_; + while (getTcpStore) { + c10d::PrefixStore *asPrefixStore = dynamic_cast(getTcpStore.get()); + c10d::TCPStore *tcpStore = dynamic_cast(getTcpStore.get()); + if (tcpStore) { + if (!(tcpStore->getHost().empty())) { + tcpMasterAddr = tcpStore->getHost(); + tcpMasterPort = tcpStore->getPort(); + break; + } + } + if (asPrefixStore) { + getTcpStore = asPrefixStore->getUnderlyingStore(); + } else { + break; + } + } + const char* blockingWait = getenv(HCCL_BLOCKING_WAIT); try { if (blockingWait != nullptr) { @@ -2148,6 +2171,30 @@ std::vector>& ProcessGroupHCCL::getHCCLComm( return createHCCLComm(devicesKey, devices, commType, commConfig, p2pRank); } +void ProcessGroupHCCL::setNSLBCommConfig(HcclCommConfig** commConfig) +{ + const char* envPtr = std::getenv("RANK"); + if (envPtr == nullptr) { + ASCEND_LOGI("Failed to get env info for NSLB-DP."); + return; + } + uint32_t worldRankID = std::stoi(std::string(envPtr)); + options_->hccl_config["hccl_world_rank_id"] = worldRankID; + uint32_t masterPort = tcpMasterPort; + struct sockaddr_in sa; + std::string master_addr = tcpMasterAddr; + inet_pton(AF_INET, std::string(master_addr).c_str(), &(sa.sin_addr)); + uint32_t masterIp = ntohl(sa.sin_addr.s_addr); + uint64_t jobID = masterPort; + jobID = (jobID << NSLB_JOBID_OFFSET); + jobID += masterIp; + options_->hccl_config["hccl_job_id"] = jobID; + if ((*commConfig) != nullptr) { + (*commConfig)->hcclWorldRankID = worldRankID; + (*commConfig)->hcclJobID = jobID; + } +} + void ProcessGroupHCCL::createHCCLComm( const std::string& devicesKey, const std::vector& devices, @@ -2172,6 +2219,10 @@ void ProcessGroupHCCL::createHCCLComm( HcclCommConfig config; + if (options_->global_ranks_in_group.empty()) { + setNSLBCommConfig(&commConfig); + } + npuGuard.set_index(devices[i].index()); switch (commType) { case HcclCommType::DEFAULT: @@ -3093,6 +3144,22 @@ HcclCommConfig ProcessGroupHCCL::createHcclCommConfigWithOptions() } } + if (options_->hccl_config.find("hccl_world_rank_id") != options_->hccl_config.end()) { + if (std::holds_alternative(options_->hccl_config["hccl_world_rank_id"])) { + config.hcclOpExpansionMode = std::get(options_->hccl_config["hccl_world_rank_id"]); + } else { + TORCH_CHECK(false, "Value type of hccl_world_rank_id should be int.", DIST_ERROR(ErrCode::TYPE)); + } + } + + if (options_->hccl_config.find("hccl_job_id") != options_->hccl_config.end()) { + if (std::holds_alternative(options_->hccl_config["hccl_job_id"])) { + config.hcclOpExpansionMode = std::get(options_->hccl_config["hccl_job_id"]); + } else { + TORCH_CHECK(false, "Value type of hccl_job_id should be int.", DIST_ERROR(ErrCode::TYPE)); + } + } + return config; } diff --git a/torch_npu/csrc/distributed/ProcessGroupHCCL.hpp b/torch_npu/csrc/distributed/ProcessGroupHCCL.hpp index 4021373b52b42290db011dc93094df4784e99842..e74714f732d9ba165718e78013c0150f0043b594 100644 --- a/torch_npu/csrc/distributed/ProcessGroupHCCL.hpp +++ b/torch_npu/csrc/distributed/ProcessGroupHCCL.hpp @@ -384,7 +384,7 @@ public: return c10::make_intrusive(_is_high_priority_stream); } - std::unordered_map> hccl_config; + std::unordered_map> hccl_config; std::chrono::milliseconds opTimeout; // Schedule HCCL operations on high priority CUDA streams @@ -571,6 +571,8 @@ public: void resumeHcclComm(int device_id); + void setNSLBCommConfig(HcclCommConfig** commConfig); + bool setCommWorkingDevNic( const HcclComm& comm, int nranks, @@ -953,6 +955,10 @@ protected: static std::string exceptionMessage_; + std::string tcpMasterAddr; + + uint32_t tcpMasterPort; + private: // Helper that encapsulates work shared across all collective communication // primitives.