diff --git a/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp b/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp
index 8ffaf3901cd473aa0abf511c441199348d4a9b88..76807cbc4931201098261d8e4b8032bba0f2237a 100644
--- a/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp
+++ b/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp
@@ -1181,6 +1181,7 @@ void ProcessGroupHCCL::abortAndClearHcclComm(c10::optional<std::string> abortReason)
     abortCommsFromMap(devHCCLCommMap_, rank_, abortReason);
     devHCCLCommMap_.clear();
     devHCCLCommNameMap_.clear();
+    p2pSendRecvKeys_.clear();
     hcclCommCounter_ = 0;
     return;
 }
@@ -1223,6 +1224,7 @@ ProcessGroupHCCL::~ProcessGroupHCCL()
             }
         }
         devHCCLCommMap_.clear();
+        p2pSendRecvKeys_.clear();
     }
     ASCEND_LOGI("process group destroyed, group id is %s.", options_->group_id.c_str());
     logger->info("process group destroyed, group id is %s.", options_->group_id.c_str());
@@ -2309,6 +2311,9 @@ bool ProcessGroupHCCL::createHCCLCommEx(
             return false;
         }
         hcclComms[i] = subComm;
+        if (commType == HcclCommType::P2P) {
+            hcclComms[i]->p2pPeer = getP2pPeer();
+        }
         // Creates the HCCL streams
         streamVal.push_back(getNPUStreamByCurrentType(devices[i].index()));
     }
@@ -2411,6 +2416,16 @@ std::vector<std::shared_ptr<HCCLComm>>& ProcessGroupHCCL::createHCCLComm(
 
     // Move the HCCL resource to cache
     devHCCLCommMap_.emplace(devicesKey, std::move(hcclComms));
+    if (commType == HcclCommType::P2P) {
+        int deviceId = -1;
+        NPU_CHECK_ERROR(c10_npu::GetDevice(&deviceId));
+        auto iter = p2pSendRecvKeys_.find(deviceId);
+        if (iter == p2pSendRecvKeys_.end()) {
+            p2pSendRecvKeys_.emplace(deviceId, std::vector<std::string>{devicesKey});
+        } else {
+            iter->second.push_back(devicesKey);
+        }
+    }
     return devHCCLCommMap_[devicesKey];
 }
 
@@ -2767,14 +2782,16 @@ void ProcessGroupHCCL::resumeHcclComm(int device_id)
             HCCL_CHECK_ERROR(at_npu::hccl::HcclCommResumeFace(comm));
         }
     }
-    if (hcclCommInitRootInfoConfigExist() && c10_npu::option::OptionsManager::GetP2PBufferSize() != 0) {
-        key = getKeySendRecv(rank_, getP2pPeer());
-        if (devHCCLCommMap_.find(key) != devHCCLCommMap_.end()) {
-            // Reuse the cached communicator if there is one.
-            auto& hcclComms = devHCCLCommMap_[key];
-            for (const auto& hcclComm : hcclComms) {
-                auto comm = hcclComm->getHcclComm();
-                HCCL_CHECK_ERROR(at_npu::hccl::HcclCommResumeFace(comm));
+    if (p2pSendRecvKeys_.find(device_id) != p2pSendRecvKeys_.end()) {
+        auto p2pKeys = p2pSendRecvKeys_[device_id];
+        for (const auto& p2pKey : p2pKeys) {
+            if (devHCCLCommMap_.find(p2pKey) != devHCCLCommMap_.end()) {
+                // Reuse the cached communicator if there is one.
+                auto& hcclComms = devHCCLCommMap_[p2pKey];
+                for (const auto& hcclComm : hcclComms) {
+                    auto comm = hcclComm->getHcclComm();
+                    HCCL_CHECK_ERROR(at_npu::hccl::HcclCommResumeFace(comm));
+                }
             }
         }
     }
diff --git a/torch_npu/csrc/distributed/ProcessGroupHCCL.hpp b/torch_npu/csrc/distributed/ProcessGroupHCCL.hpp
index 9c2f365b3eb7a214df5a07f72cfb329e301432d1..583cd214919644be91329eb8204b22fac15fc75b 100644
--- a/torch_npu/csrc/distributed/ProcessGroupHCCL.hpp
+++ b/torch_npu/csrc/distributed/ProcessGroupHCCL.hpp
@@ -746,6 +746,8 @@ protected:
     //
     // Note that the order of the device for the tensor list matters.
     std::unordered_map<std::string, std::vector<std::shared_ptr<HCCLComm>>> devHCCLCommMap_;
+
+    std::unordered_map<int, std::vector<std::string>> p2pSendRecvKeys_;
     std::unordered_map<HcclComm, std::string> devHCCLCommNameMap_;