diff --git a/third_party/acl/libs/hccl.cpp b/third_party/acl/libs/hccl.cpp
index ebf5b401f3731d35f86316505e7a83ca55d9991e..90ba1b71f5354f9c8eb8fb01bc8885d39e349706 100644
--- a/third_party/acl/libs/hccl.cpp
+++ b/third_party/acl/libs/hccl.cpp
@@ -34,4 +34,5 @@ hcclResult_t HcclBatchSendRecv(HcclSendRecvItemDef* sendRecvInfo, u32 itemNum, h
 aclrtStream stream) {return HCCL_SUCCESS;}
 hcclResult_t HcclCommInitAll(u32 ndev, s32 *devices, hcclComm_t *comms) {return HCCL_SUCCESS;}
 hcclResult_t HcclCommResume(hcclComm_t comm) {return HCCL_SUCCESS;}
-hcclResult_t HcclSetGlobalCommInfo(u32 masterIp, u32 masterPort, u32 totalRankSize, u32 nodeID, u32 localRankSize){return HCCL_SUCCESS;}
\ No newline at end of file
+hcclResult_t HcclSetGlobalCommInfo(u32 masterIp, u32 masterPort, u32 totalRankSize, u32 nodeID, u32 localRankSize){return HCCL_SUCCESS;}
+hcclResult_t HcclCommWorkingDevNicSet(HcclComm comm, u32 *ranks, bool *useBackup, u32 nRanks){return HCCL_SUCCESS;}
\ No newline at end of file
diff --git a/third_party/acl/libs/hccl.h b/third_party/acl/libs/hccl.h
index 6c87438c2e1795ca04874bf348565908af6ac146..60a08307ee52c2e818240d29916672e8fc52d13a 100644
--- a/third_party/acl/libs/hccl.h
+++ b/third_party/acl/libs/hccl.h
@@ -109,4 +109,5 @@ hcclResult_t HcclBatchSendRecv(HcclSendRecvItemDef* sendRecvInfo, u32 itemNum, h
 hcclResult_t HcclCommInitAll(u32 ndev, s32 *devices, hcclComm_t *comms);
 hcclResult_t HcclCommResume(hcclComm_t comm);
 hcclResult_t HcclSetGlobalCommInfo(u32 masterIp, u32 masterPort, u32 totalRankSize, u32 nodeID, u32 localRankSize);
+hcclResult_t HcclCommWorkingDevNicSet(HcclComm comm, u32 *ranks, bool *useBackup, u32 nRanks);
 }
diff --git a/third_party/hccl/inc/hccl/hccl.h b/third_party/hccl/inc/hccl/hccl.h
index 0401b4a60746c6affcd5c01387f30b8051e728f8..f41784cec07b2eccbd781c185011822bf8f070e8 100644
--- a/third_party/hccl/inc/hccl/hccl.h
+++ b/third_party/hccl/inc/hccl/hccl.h
@@ -185,6 +185,8 @@ extern HcclResult HcclCommResume(HcclComm comm);
 extern HcclResult HcclSetGlobalCommInfo(uint32_t masterIp, uint32_t masterPort, uint32_t totalRankSize,
     uint32_t nodeID, uint32_t localRankSize);
 
+extern HcclResult HcclCommWorkingDevNicSet(HcclComm comm, uint32_t *ranks, bool *useBackup, uint32_t nRanks);
+
 /**
 * @brief Initialize the comm configuration.
 * @param config Pointer to the comm configuration that needs to be initialized.
diff --git a/torch_npu/csrc/distributed/HCCLUtils.hpp b/torch_npu/csrc/distributed/HCCLUtils.hpp
index cbc5491735e44f7effc4027c55b99280b9b432a9..ad6420e5cc2189b185018fc0f614865e63e07968 100644
--- a/torch_npu/csrc/distributed/HCCLUtils.hpp
+++ b/torch_npu/csrc/distributed/HCCLUtils.hpp
@@ -64,6 +64,7 @@ extern HcclResult hcclCommInitClusterInfoConfig(const char *clusterInfo, uint32_
 extern HcclResult hcclCreateSubCommConfig(HcclComm *comm, uint32_t rankNum, uint32_t *rankIds, uint64_t subCommId,
     uint32_t subCommRankId, HcclCommConfig* config, HcclComm *subComm);
 extern HcclResult hcclSetGlobalCommInfo(uint32_t masterIp, uint32_t masterPort, uint32_t totalRankSize, uint32_t nodeID, uint32_t localRankSize);
+extern HcclResult hcclCommWorkingDevNicSet(HcclComm comm, uint32_t *ranks, bool *useBackup, uint32_t nRanks);
 
 // Provides additional detail into HCCL error codes based on when these are
 // thrown in the HCCL codebase.
diff --git a/torch_npu/csrc/distributed/HcclCompile.h b/torch_npu/csrc/distributed/HcclCompile.h
index c9027922a048f58e3ed9dd9fbaa519a240791aae..409a203b02b5cc7ff96a951742639b58b1954513 100644
--- a/torch_npu/csrc/distributed/HcclCompile.h
+++ b/torch_npu/csrc/distributed/HcclCompile.h
@@ -27,6 +27,7 @@ LOAD_FUNCTION(HcclGetCommConfigCapability)
 LOAD_FUNCTION(HcclCommInitClusterInfoConfig)
 LOAD_FUNCTION(HcclCreateSubCommConfig)
 LOAD_FUNCTION(HcclSetGlobalCommInfo)
+LOAD_FUNCTION(HcclCommWorkingDevNicSet)
 
 extern HcclResult hcclAlltoAllV(const void *sendBuf, const void *sendCounts, const void *sdispls,
@@ -281,4 +282,25 @@ HcclResult hcclSetGlobalCommInfo(uint32_t masterIp, uint32_t masterPort, uint32_
     auto ret = func(masterIp, masterPort, totalRankSize, nodeID, localRankSize);
     return ret;
 }
+
+bool hcclCommWorkingDevNicSetExist()
+{
+    const static bool isHcclCommWorkingDevNicSetExist = []() -> bool {
+        auto func = GET_FUNC(HcclCommWorkingDevNicSet)
+        return func != nullptr;
+    }();
+    return isHcclCommWorkingDevNicSetExist;
+}
+
+HcclResult hcclCommWorkingDevNicSet(HcclComm comm, uint32_t *ranks, bool *useBackup, uint32_t nRanks)
+{
+    using HcclCommWorkingDevNicSetFunc = HcclResult(*)(HcclComm, uint32_t *, bool *, uint32_t);
+    static HcclCommWorkingDevNicSetFunc func = nullptr;
+    if (func == nullptr) {
+        func = (HcclCommWorkingDevNicSetFunc)GET_FUNC(HcclCommWorkingDevNicSet)
+    }
+    TORCH_CHECK(func, "Failed to find function ", "HcclCommWorkingDevNicSet", DIST_ERROR(ErrCode::NOT_FOUND));
+    auto ret = func(comm, ranks, useBackup, nRanks);
+    return ret;
+}
 } // namespace c10d_npu
diff --git a/torch_npu/csrc/distributed/Init.cpp b/torch_npu/csrc/distributed/Init.cpp
index 1df6943f625f933872bf6522fd38c9d8b767cc04..29cfafdef94131ef72a3685d9425c461671c88a7 100644
--- a/torch_npu/csrc/distributed/Init.cpp
+++ b/torch_npu/csrc/distributed/Init.cpp
@@ -396,6 +396,12 @@ PyObject* c10d_npu_init(PyObject* _unused, PyObject* noargs)
         .def("get_hccl_comm", &::c10d_npu::ProcessGroupHCCL::getHcclComm)
         .def("_set_hccl_comm_name", &::c10d_npu::ProcessGroupHCCL::setHcclCommName)
         .def("resume_hccl_comm", &::c10d_npu::ProcessGroupHCCL::resumeHcclComm)
+        .def("_get_switch_nic_comm",
+            &::c10d_npu::ProcessGroupHCCL::getSwitchNicComm,
+            py::arg("rankid"),
+            py::arg("nRanks"),
+            py::arg("ranks") = std::vector<uint32_t>{},
+            py::arg("useBackup") = std::vector<bool>{})
         .def("abort_hccl_comm", &::c10d_npu::ProcessGroupHCCL::abortAndClearHcclComm)
         .def("_delete_tcpstore_key", &::c10d_npu::ProcessGroupHCCL::deleteTCPStoreKey)
         .def("set_watchdog_status", &::c10d_npu::ProcessGroupHCCL::setWatchdogStatus)
diff --git a/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp b/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp
index 5a2e407f2cb51eaacabaaf2358a943544dbee652..e691b82fb1c7b51f2f83cd43cb9883d0deeba1af 100644
--- a/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp
+++ b/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp
@@ -28,6 +28,7 @@
 #include "third_party/acl/inc/acl/acl_base.h"
 #include "torch_npu/csrc/aten/CustomFunctions.h"
 #include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+#include "torch_npu/csrc/core/npu/GetCANNInfo.h"
 #include "torch_npu/csrc/core/npu/NPUFunctions.h"
 #include "torch_npu/csrc/core/NPUBridge.h"
 #include "torch_npu/csrc/core/NPUStorageImpl.h"
@@ -289,7 +290,13 @@ void getHcclCommConfig(HcclCommConfig* config, bool isP2P = false)
     }
 
     // Temporarily adding this logic to set deterministic states to avoid a known issues within HCCL.
-    config->hcclDeterministic = getDeterministicState() ? 1 : 0;
+    const std::string baseCannVersion = "8.2.RC1";
+    const std::string baseCannModule = "CANN";
+    if (IsGteCANNVersion(baseCannVersion, baseCannModule)) {
+        config->hcclDeterministic = 0xffffffff;
+    } else {
+        config->hcclDeterministic = getDeterministicState() ? 1 : 0;
+    }
 
     // Compatible with the size check of the old version of HCCL, forcibly convert
     // the config object to a size_t=32 object, and retain the N ± 2 version
@@ -2787,6 +2794,65 @@ void ProcessGroupHCCL::resumeHcclComm(int device_id)
     ASCEND_LOGI("resumeHcclComm success, group id is %s.", options_->group_id.c_str());
 }
 
+bool ProcessGroupHCCL::getSwitchNicComm(int rankid, int nranks, std::vector<uint32_t>& ranks, std::vector<bool>& useBackup)
+{
+    if (!hcclCommWorkingDevNicSetExist()) {
+        ASCEND_LOGI("The hcclCommWorkingDevNicSet does not exist. Skip it.");
+        return true;
+    }
+    if (options_->global_ranks_in_group.empty()) {
+        return true;
+    }
+    uint32_t sendnRank = 0;
+    std::vector<uint32_t> sendRanks;
+    std::vector<bool> sendUseBackup;
+    at::Device device = getDeviceForRank(rankid);
+    std::vector<at::Device> devices = {device};
+    auto key = getKeyFromDevices(devices);
+    HcclComm comm;
+    {
+        std::lock_guard<std::mutex> lock(mutex_);
+        if (devHCCLCommMap_.find(key) != devHCCLCommMap_.end()) {
+            auto& hcclComms = devHCCLCommMap_[key];
+            for (auto& hcclComm : hcclComms) {
+                comm = hcclComm->getHcclComm();
+            }
+        } else {
+            return true;
+        }
+    }
+    HcclComm sendComm = comm;
+    for (int i = 0; i < nranks; i++) {
+        uint32_t localrank = 0;
+        for (uint32_t val : options_->global_ranks_in_group) {
+            if (ranks[i] == val) {
+                sendRanks.push_back(localrank);
+                sendUseBackup.push_back(useBackup[i]);
+                sendnRank++;
+                break;
+            }
+            localrank++;
+        }
+    }
+    if (sendnRank == 0) {
+        return true;
+    }
+    bool useBackupArr[sendUseBackup.size()];
+    uint32_t sendRanksArr[sendRanks.size()];
+    for (size_t i = 0; i < sendnRank; i++) {
+        useBackupArr[i] = sendUseBackup[i];
+        sendRanksArr[i] = sendRanks[i];
+    }
+    auto ret = hcclCommWorkingDevNicSet(sendComm, sendRanksArr, useBackupArr, sendnRank);
+    if (ret != HCCL_SUCCESS) {
+        ASCEND_LOGI("Fail to hcclCommWorkingDevNicSet");
+        return false;
+    } else {
+        ASCEND_LOGI("Succeed to hcclCommWorkingDevNicSet");
+    }
+    return true;
+}
+
 void ProcessGroupHCCL::setWatchdogStatus(int status)
 {
     watchdogStatus = WatchdogStatus(status);
diff --git a/torch_npu/csrc/distributed/ProcessGroupHCCL.hpp b/torch_npu/csrc/distributed/ProcessGroupHCCL.hpp
index fe3315196c6841ab1b53f5cd50b18ae31ffb1f17..d8832af3a9dd637c59127a3548e6b6e3262a3c95 100644
--- a/torch_npu/csrc/distributed/ProcessGroupHCCL.hpp
+++ b/torch_npu/csrc/distributed/ProcessGroupHCCL.hpp
@@ -575,6 +575,8 @@ public:
     void resumeHcclComm(int device_id);
 
+    bool getSwitchNicComm(int rankid, int nRanks, std::vector<uint32_t>& ranks, std::vector<bool>& useBackup);
+
     void setWatchdogStatus(int status);
 
     void clearWorkMetaList();
diff --git a/torch_npu/distributed/distributed_c10d.py b/torch_npu/distributed/distributed_c10d.py
index 6558856d503135894e059ec5cb9116d97a0ecf94..f43f416d08a5ef08e76e56c7ef324ed9ca4a334a 100644
--- a/torch_npu/distributed/distributed_c10d.py
+++ b/torch_npu/distributed/distributed_c10d.py
@@ -239,6 +239,20 @@ def reinit_process_group(group=None, rebuild_link=True):
     return group
 
 
+def _comm_switch_nic(ranks, useBackup):
+    nRanks = len(ranks)
+    npu_device = torch.device('npu')
+    rankid = int(os.environ['RANK'])
+    result = True
+    for pg in _pg_map:
+        if (npu_device in pg._device_types):
+            presult = pg._get_backend(npu_device)._get_switch_nic_comm(rankid, nRanks, ranks, useBackup)
+            if not presult:
+                result = False
+    if not result:
+        return False
+    return True
+
 def _reduce_scatter_tensor_uneven(output, input, input_split_sizes=None, op=dist.ReduceOp.SUM,
                                   group=None, async_op=False):
     if _rank_not_in_group(group):
diff --git a/torch_npu/npu/__init__.py b/torch_npu/npu/__init__.py
index b451ac72cf0a462a3759f2bc6f1c067dbdfee780..91d8ac4eb4bc9cf49291a7716daf4af2025505d4 100644
--- a/torch_npu/npu/__init__.py
+++ b/torch_npu/npu/__init__.py
@@ -484,6 +484,11 @@ def _lazy_new(cls, *args, **kwargs):
     return super(_NPUBase, cls).__new__(cls, *args, **kwargs)
 
 
+def _comm_switch_nic(ranks, useBackup):
+    torch_npu.npu.synchronize()
+    return torch_npu.distributed.distributed_c10d._comm_switch_nic(ranks, useBackup)
+
+
 class _NPUBase:
     is_npu = True
     is_sparse = False
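
For reviewers, a minimal usage sketch of the NIC-switch entry point added above. It assumes an initialized HCCL process group on NPU devices and a RANK environment variable (both required by _comm_switch_nic); the helper name and the choice to move every listed rank onto its backup NIC are illustrative assumptions, not part of this patch.

```python
# Illustrative sketch only, not part of the patch.
import torch_npu

def switch_ranks_to_backup_nic(ranks):
    # Ask every NPU process group to move the given global ranks onto their
    # backup NIC: torch_npu.npu._comm_switch_nic synchronizes the device and
    # then calls HcclCommWorkingDevNicSet on each matching HCCL communicator.
    use_backup = [True] * len(ranks)
    if not torch_npu.npu._comm_switch_nic(ranks, use_backup):
        raise RuntimeError("HcclCommWorkingDevNicSet failed for at least one process group")
```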