From 41d3ad89c0427f3e2b1070665be22bced489ba97 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=97=AB=E9=B9=8F=E5=85=A8?= Date: Wed, 10 Sep 2025 16:59:28 +0800 Subject: [PATCH] support ccl buff sharing --- third_party/hccl/inc/hccl/hccl.h | 1 + third_party/hccl/inc/hccl/hccl_types.h | 4 +- torch_npu/csrc/distributed/HCCLUtils.cpp | 38 ++++++++++++++++++ torch_npu/csrc/distributed/HCCLUtils.hpp | 5 +++ .../csrc/distributed/ProcessGroupHCCL.cpp | 40 +++++++++++++++++-- .../csrc/distributed/ProcessGroupHCCL.hpp | 2 + 6 files changed, 86 insertions(+), 4 deletions(-) diff --git a/third_party/hccl/inc/hccl/hccl.h b/third_party/hccl/inc/hccl/hccl.h index 216ef7a8384..3b5c5da87fb 100644 --- a/third_party/hccl/inc/hccl/hccl.h +++ b/third_party/hccl/inc/hccl/hccl.h @@ -214,6 +214,7 @@ inline void HcclCommConfigInit(HcclCommConfig *config) config->hcclOpExpansionMode = HCCL_COMM_DEFAULT_OP_EXPANSION_MODE; config->hcclWorldRankID = 0; config->hcclJobID = 0; + config->hcclBufferName[0] = '\0'; } /** diff --git a/third_party/hccl/inc/hccl/hccl_types.h b/third_party/hccl/inc/hccl/hccl_types.h index 9a02c61c041..7fb30f2ff20 100644 --- a/third_party/hccl/inc/hccl/hccl_types.h +++ b/third_party/hccl/inc/hccl/hccl_types.h @@ -15,7 +15,7 @@ extern "C" { const uint32_t HCCL_COMM_CONFIG_INFO_BYTES = 24; const uint32_t HCCL_COMM_CONFIG_MAGIC_WORD = 0xf0f0f0f0; -const uint32_t HCCL_COMM_CONFIG_VERSION = 6; +const uint32_t HCCL_COMM_CONFIG_VERSION = 7; const uint32_t HCCL_COMM_DEFAULT_BUFFSIZE = 200; // 200MB buffer size const uint32_t HCCL_COMM_DEFAULT_DETERMINISTIC = 0; // Disable deterministic calculations const uint32_t COMM_NAME_MAX_LENGTH = 128; @@ -23,6 +23,7 @@ const uint32_t UDI_MAX_LENGTH = 128; const uint32_t HCCL_COMM_TRAFFIC_CLASS_CONFIG_NOT_SET = 0xffffffff; const uint32_t HCCL_COMM_SERVICE_LEVEL_CONFIG_NOT_SET = 0xffffffff; const uint32_t HCCL_COMM_DEFAULT_OP_EXPANSION_MODE = 0; +const uint32_t BUFFER_NAME_MAX_LENGTH = 128; // cclbuffer name max length /** * @brief HCCL functions return value definition @@ -134,6 +135,7 @@ typedef struct HcclCommConfigDef { uint32_t hcclRdmaServiceLevel; uint32_t hcclWorldRankID; uint64_t hcclJobID; + char hcclBufferName[BUFFER_NAME_MAX_LENGTH]; } HcclCommConfig; typedef enum { diff --git a/torch_npu/csrc/distributed/HCCLUtils.cpp b/torch_npu/csrc/distributed/HCCLUtils.cpp index a21b2ef2cde..3555ec99906 100644 --- a/torch_npu/csrc/distributed/HCCLUtils.cpp +++ b/torch_npu/csrc/distributed/HCCLUtils.cpp @@ -272,4 +272,42 @@ void DebugInfoWriter::registerWriter(std::unique_ptr writer) std::unique_ptr DebugInfoWriter::writer_ = nullptr; std::atomic DebugInfoWriter::hasWriterRegistered_(false); + +struct HcclBufferNameKey { + c10::DeviceIndex device_index; + std::string name; + // 重载 < 运算符 + bool operator<(const HcclBufferNameKey& other) const { + if (device_index != other.device_index) { + return device_index < other.device_index; + } + return name < other.name; // 如果 id 相同,按 name 排序 + } +}; + +struct HcclBufferNameStreamMap { + std::map map; + std::mutex mutex; +} g_BufferNameStreamMap = {}; + +c10::optional getHcclStreamByBufferName(const std::string &name, c10::DeviceIndex device_index) +{ + std::unique_lock lock(g_BufferNameStreamMap.mutex); + auto &map = g_BufferNameStreamMap.map; + auto it = map.find({device_index, name}); + if (it == map.end()) { + return {}; + } + return it->second; +} + +bool setHcclStreamByBufferName(const std::string &name, c10::DeviceIndex device_index, c10_npu::NPUStream steam) +{ + HcclBufferNameKey key = {device_index, name}; + std::unique_lock lock(g_BufferNameStreamMap.mutex); + auto &map = g_BufferNameStreamMap.map; + auto pair = map.insert({key, steam}); + return pair.second; +} + } // namespace c10d_npu diff --git a/torch_npu/csrc/distributed/HCCLUtils.hpp b/torch_npu/csrc/distributed/HCCLUtils.hpp index 1033d8de97f..0dc8883c2c6 100644 --- a/torch_npu/csrc/distributed/HCCLUtils.hpp +++ b/torch_npu/csrc/distributed/HCCLUtils.hpp @@ -6,8 +6,10 @@ #include "torch_npu/csrc/core/npu/npu_log.h" #include "torch_npu/csrc/core/npu/sys_ctrl/npu_sys_ctrl.h" #include "torch_npu/csrc/core/npu/NPUException.h" +#include "torch_npu/csrc/core/npu/NPUStream.h" #include +#include #include #include "third_party/hccl/inc/hccl/hccl.h" #include "third_party/hccl/inc/hccl/hccl_types.h" @@ -162,4 +164,7 @@ private: static std::unique_ptr writer_; static std::atomic hasWriterRegistered_; }; + +c10::optional getHcclStreamByBufferName(const std::string &name, c10::DeviceIndex device_index); +bool setHcclStreamByBufferName(const std::string &name, c10::DeviceIndex device_index, c10_npu::NPUStream steam); } // namespace c10d_npu diff --git a/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp b/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp index 07907b465b0..7f689aa3deb 100644 --- a/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp +++ b/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp @@ -2203,6 +2203,26 @@ void ProcessGroupHCCL::setNSLBCommConfig(HcclCommConfig** commConfig) } } +c10_npu::NPUStream ProcessGroupHCCL::getHcclNPUStream(const at::Device &device) +{ + auto it = options_->hccl_config.find("hccl_buffer_name"); + if (it == options_->hccl_config.end()) { + return getNPUStreamByCurrentType(device.index()); + } + auto bufferName = std::get(it->second); + + auto stream = getHcclStreamByBufferName(bufferName, device.index()); + if (stream) { + ASCEND_LOGD("HCCL use the same steam with bufferName = %s, device_index = %d, stream id = %lu", bufferName.c_str(), device.index(), stream->id()); + return stream.value(); + } + + auto newStream = getNPUStreamByCurrentType(device.index()); + auto result = setHcclStreamByBufferName(bufferName, device.index(), newStream); + ASCEND_LOGD("HCCL use alloc new stream with bufferName = %s, device_index = %d, stream id = %lu. result = %d", bufferName.c_str(), device.index(), newStream.id(), result); + return newStream; +} + void ProcessGroupHCCL::createHCCLCommOrigin( const std::string& devicesKey, const std::vector& devices, @@ -2258,7 +2278,7 @@ void ProcessGroupHCCL::createHCCLCommOrigin( } // Creates the HCCL streams - streamVal.push_back(getNPUStreamByCurrentType(devices[i].index())); + streamVal.push_back(getHcclNPUStream(devices[i])); } auto endTime = std::chrono::steady_clock::now(); auto timeElapsed = std::chrono::duration_cast(endTime - startTime); @@ -2308,7 +2328,7 @@ bool ProcessGroupHCCL::createHCCLCommEx( } hcclComms[i] = comm; // Creates the HCCL streams - streamVal.push_back(getNPUStreamByCurrentType(devices[i].index())); + streamVal.push_back(getHcclNPUStream(devices[i])); } auto endTime = std::chrono::steady_clock::now(); auto timeElapsed = std::chrono::duration_cast(endTime - startTime); @@ -2383,7 +2403,7 @@ bool ProcessGroupHCCL::createHCCLCommEx( hcclComms[i]->p2pPeer = getP2pPeer(); } // Creates the HCCL streams - streamVal.push_back(getNPUStreamByCurrentType(devices[i].index())); + streamVal.push_back(getHcclNPUStream(devices[i])); } auto subEndTime = std::chrono::steady_clock::now(); auto subTimeElapsed = std::chrono::duration_cast(subEndTime - subStartTime); @@ -3217,6 +3237,20 @@ HcclCommConfig ProcessGroupHCCL::createHcclCommConfigWithOptions() } } + if (options_->hccl_config.find("hccl_buffer_name") != options_->hccl_config.end()) { + if (std::holds_alternative(options_->hccl_config["hccl_buffer_name"])) { + auto bufferName = std::get(options_->hccl_config["hccl_buffer_name"]); + uint32_t length = bufferName.length(); + if (length >= BUFFER_NAME_MAX_LENGTH) { + length = BUFFER_NAME_MAX_LENGTH - 1; + TORCH_NPU_WARN("The length of hccl_buffer_name has exceeded the limit BUFFER_NAME_MAX_LENGTH(128) which will be truncated to BUFFER_NAME_MAX_LENGTH - 1."); + } + strncpy(config.hcclBufferName, bufferName.c_str(), length); + config.hcclBufferName[length] = '\0'; + } else { + TORCH_CHECK(false, "Value type of hccl_buffer_name should be string.", DIST_ERROR(ErrCode::TYPE)); + } + } return config; } diff --git a/torch_npu/csrc/distributed/ProcessGroupHCCL.hpp b/torch_npu/csrc/distributed/ProcessGroupHCCL.hpp index 2be4adcb3ed..ecfc606fe88 100644 --- a/torch_npu/csrc/distributed/ProcessGroupHCCL.hpp +++ b/torch_npu/csrc/distributed/ProcessGroupHCCL.hpp @@ -1089,6 +1089,8 @@ private: HcclCommConfig createHcclCommConfigWithOptions(); + c10_npu::NPUStream getHcclNPUStream(const at::Device &device); + static std::string getMstxHcclMsg(const std::string &opName, uint64_t dataCnt, HcclDataType hcclType, -- Gitee