diff --git a/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp b/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp index 1d14cf06ef5ce218d7fa06037987b0776e693aca..24720d902dced30c19b45848933a97f4f9796db7 100644 --- a/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp +++ b/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp @@ -2749,7 +2749,7 @@ void ProcessGroupHCCL::resumeHcclComm(int device_id) { at::Device device = at::Device(c10::DeviceType::PrivateUse1, device_id); std::vector devices = {device}; - const auto key = getKeyFromDevices(devices); + auto key = getKeyFromDevices(devices); { std::lock_guard lock(mutex_); @@ -2761,6 +2761,17 @@ void ProcessGroupHCCL::resumeHcclComm(int device_id) HCCL_CHECK_ERROR(at_npu::hccl::HcclCommResumeFace(comm)); } } + if (hcclCommInitRootInfoConfigExist() && c10_npu::option::OptionsManager::GetP2PBufferSize() != 0) { + key = getKeySendRecv(rank_, getP2pPeer()); + if (devHCCLCommMap_.find(key) != devHCCLCommMap_.end()) { + // Reuse the cached communicator if there is one. + auto& hcclComms = devHCCLCommMap_[key]; + for (const auto& hcclComm : hcclComms) { + auto comm = hcclComm->getHcclComm(); + HCCL_CHECK_ERROR(at_npu::hccl::HcclCommResumeFace(comm)); + } + } + } } ASCEND_LOGI("resumeHcclComm success, group id is %s.", options_->group_id.c_str()); }