diff --git a/test/distributed/test_api_get_hcom_name.py b/test/distributed/test_api_get_hcom_name.py
new file mode 100644
index 0000000000000000000000000000000000000000..5005393b43a117962dcfc58c3229db52db6d7a20
--- /dev/null
+++ b/test/distributed/test_api_get_hcom_name.py
@@ -0,0 +1,54 @@
+import os
+
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from torch.distributed.distributed_c10d import _world
+import torch_npu
+from torch_npu.testing.common_distributed import skipIfUnsupportMultiNPU
+from torch_npu.testing.testcase import TestCase, run_tests
+
+
+class GetHcclCommNameTest(TestCase):
+    @classmethod
+    def _init_dist_hccl(cls, rank, world_size):
+        os.environ['MASTER_ADDR'] = '127.0.0.1'
+        os.environ['MASTER_PORT'] = '29500'
+        os.environ['HCCL_WHITELIST_DISABLE'] = '1'
+        torch_npu.npu.set_device(rank)
+        dist.init_process_group(backend='hccl', world_size=world_size, rank=rank)
+        return dist
+
+    @classmethod
+    def _test_hccl_name(cls, rank, world_size, init_pg):
+        dist_group = init_pg(rank, world_size)
+        pg1 = torch.distributed.new_group()
+        assert pg1._get_backend(torch.device('npu')).get_hccl_comm_name(rank) != ""
+        pg2 = torch.distributed.new_group()
+        assert pg2._get_backend(torch.device('npu')).get_hccl_comm_name(rank, init_comm=False) == ""
+        assert pg2._get_backend(torch.device('npu')).get_hccl_comm_name(rank, init_comm=True) != ""
+        pg3 = torch.distributed.new_group()
+        assert pg3._get_backend(torch.device('npu')).get_hccl_comm_name(rank, init_comm=True) != ""
+
+    def _test_multiprocess(self, f, init_pg, world_size):
+        ctx = mp.get_context('spawn')
+        ps = []
+        for rank in range(world_size):
+            p = ctx.Process(target=f, args=(rank, world_size, init_pg))
+            p.start()
+            ps.append(p)
+
+        for p in ps:
+            p.join()
+
+    @skipIfUnsupportMultiNPU(2)
+    def test_dist_get_hccl_name(self):
+        # CI currently supports only 2 devices
+        ranks = [2]
+        for world_size in ranks:
+            self._test_multiprocess(GetHcclCommNameTest._test_hccl_name,
+                                    GetHcclCommNameTest._init_dist_hccl, world_size)
+
+
+if __name__ == '__main__':
+    run_tests()
diff --git a/test/distributed/test_api_set_hcom_name.py b/test/distributed/test_api_set_hcom_name.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1ccddbe21060d80bab713504b4b807b5e4b9688
--- /dev/null
+++ b/test/distributed/test_api_set_hcom_name.py
@@ -0,0 +1,76 @@
+import os
+
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from torch.distributed.distributed_c10d import _world
+import torch_npu
+from torch_npu.testing.common_distributed import skipIfUnsupportMultiNPU
+from torch_npu.testing.testcase import TestCase, run_tests
+
+
+class SetHcclCommNameTest(TestCase):
+    @classmethod
+    def _init_dist_hccl(cls, rank, world_size):
+        os.environ['MASTER_ADDR'] = '127.0.0.1'
+        os.environ['MASTER_PORT'] = '29501'
+        os.environ['HCCL_WHITELIST_DISABLE'] = '1'
+        torch_npu.npu.set_device(rank)
+        dist.init_process_group(backend='hccl', world_size=world_size, rank=rank)
+        return dist
+
+    @classmethod
+    def _test_set_hccl_name(cls, rank, world_size, init_pg):
+        dist_group = init_pg(rank, world_size)
+
+        pg1 = torch.distributed.new_group()
+        isSupportHcclName = torch_npu.distributed._is_support_hccl_comm_name()
+        assert isSupportHcclName
+        pg1._get_backend(torch.device('npu'))._set_hccl_comm_name("test")
+        pg1._get_backend(torch.device('npu'))._set_hccl_comm_name("test")
+        pg_name = pg1._get_backend(torch.device('npu')).get_hccl_comm_name(rank)
+        assert pg_name == "test"
+        pg2 = torch.distributed.new_group()
+        pg_name = pg2._get_backend(torch.device('npu')).get_hccl_comm_name(rank)
+        pg_name_new = pg2._get_backend(torch.device('npu')).get_hccl_comm_name(rank)
+        assert pg_name == pg_name_new
+
+    def _test_multiprocess(self, f, init_pg, world_size):
+        ctx = mp.get_context('spawn')
+        ps = []
+        for rank in range(world_size):
+            p = ctx.Process(target=f, args=(rank, world_size, init_pg))
+            p.start()
+            ps.append(p)
+
+        for p in ps:
+            p.join()
+
+    @skipIfUnsupportMultiNPU(2)
+    def test_dist_set_hccl_name(self):
+        # CI currently supports only 2 devices
+        ranks = [2]
+        for world_size in ranks:
+            self._test_multiprocess(SetHcclCommNameTest._test_set_hccl_name,
+                                    SetHcclCommNameTest._init_dist_hccl, world_size)
+
+    def test_dist_set_hccl_name_case_failed(self):
+        dist_group = SetHcclCommNameTest._init_dist_hccl(0, 1)
+        pg1 = torch.distributed.new_group()
+        with self.assertRaises(RuntimeError):
+            pg1._get_backend(torch.device('npu'))._set_hccl_comm_name("")
+        with self.assertRaises(RuntimeError):
+            pg1._get_backend(torch.device('npu'))._set_hccl_comm_name(
+                "0123456789012345678901234567890123456789012345678901234567890123456789"
+                "0123456789012345678901234567890123456789012345678901234567")
+        with self.assertRaises(RuntimeError):
+            pg1._get_backend(torch.device('npu'))._set_hccl_comm_name("test")
+            pg1._get_backend(torch.device('npu'))._set_hccl_comm_name("test2")
+        with self.assertRaises(RuntimeError):
+            pg2 = torch.distributed.new_group()
+            pg2._get_backend(torch.device('npu')).get_hccl_comm_name(0)
+            pg2._get_backend(torch.device('npu'))._set_hccl_comm_name("test")
+
+
+if __name__ == '__main__':
+    run_tests()
diff --git a/third_party/hccl/inc/hccl/hccl.h b/third_party/hccl/inc/hccl/hccl.h
index 936c3869a4a9a1195fd37377bb25a8e83e838b99..de9df9cdce696e624f641e8c06a42ac7a2c3db97 100644
--- a/third_party/hccl/inc/hccl/hccl.h
+++ b/third_party/hccl/inc/hccl/hccl.h
@@ -195,8 +195,14 @@ inline void HcclCommConfigInit(HcclCommConfig *config)
 
     config->hcclBufferSize = HCCL_COMM_DEFAULT_BUFFSIZE;
     config->hcclDeterministic = HCCL_COMM_DEFAULT_DETERMINISTIC;
+    config->hcclCommName[0] = '\0';
 }
 
+/**
+ * @brief Get a number that represents the capability of comm configuration.
+*/
+extern uint32_t HcclGetCommConfigCapability();
+
 #ifdef __cplusplus
 }
 #endif // __cplusplus
diff --git a/third_party/hccl/inc/hccl/hccl_types.h b/third_party/hccl/inc/hccl/hccl_types.h
index bcbbd9ceb53c14f6137ed7cb855b474dc2e1fbc8..2fe100d7c7079f4bd694f28e8f80eae1973cdd40 100644
--- a/third_party/hccl/inc/hccl/hccl_types.h
+++ b/third_party/hccl/inc/hccl/hccl_types.h
@@ -15,9 +15,10 @@ extern "C" {
 
 const uint32_t HCCL_COMM_CONFIG_INFO_BYTES = 24;
 const uint32_t HCCL_COMM_CONFIG_MAGIC_WORD = 0xf0f0f0f0;
-const uint32_t HCCL_COMM_CONFIG_VERSION = 1;
+const uint32_t HCCL_COMM_CONFIG_VERSION = 2;
 const uint32_t HCCL_COMM_DEFAULT_BUFFSIZE = 200; // 200MB buffer size
 const uint32_t HCCL_COMM_DEFAULT_DETERMINISTIC = 0; // Disable deterministic calculations
+const uint32_t COMM_NAME_MAX_LENGTH = 128;
 
 /**
  * @brief HCCL functions return value definition
@@ -122,7 +123,15 @@ typedef struct HcclCommConfigDef {
     char reserved[HCCL_COMM_CONFIG_INFO_BYTES];
     uint32_t hcclBufferSize;
     uint32_t hcclDeterministic;
+    char hcclCommName[COMM_NAME_MAX_LENGTH];
 } HcclCommConfig;
+
+typedef enum {
+    HCCL_COMM_CONFIG_BUFFER_SIZE = 0,
+    HCCL_COMM_CONFIG_DETERMINISTIC = 1,
+    HCCL_COMM_CONFIG_COMM_NAME = 2,
+    HCCL_COMM_CONFIG_RESERVED,
+} HcclCommConfigCapability;
 #ifdef __cplusplus
 }
 #endif // __cplusplus
diff --git a/torch_npu/csrc/core/npu/interface/HcclInterface.cpp b/torch_npu/csrc/core/npu/interface/HcclInterface.cpp
index 29b65970f377aada4d0fb312d066b76cd5362051..ee62c882cc59932b85d200a27e2c16e66f35683a 100644
--- a/torch_npu/csrc/core/npu/interface/HcclInterface.cpp
+++ b/torch_npu/csrc/core/npu/interface/HcclInterface.cpp
@@ -13,6 +13,7 @@ namespace hccl {
 
 REGISTER_LIBRARY(libhccl)
 LOAD_FUNCTION(HcclGetCommName)
+LOAD_FUNCTION(HcclCommResume)
 
 extern HcclResult HcclGetCommNameFace(HcclComm commHandle, char* commName) {
     typedef HcclResult (*HcclGetCommNameFace)(HcclComm commHandle, char* commName);
@@ -25,5 +26,28 @@ extern HcclResult HcclGetCommNameFace(HcclComm commHandle, char* commName) {
         " maybe you cann version is too low, please upgrade it", DIST_ERROR(ErrCode::NOT_FOUND));
     return func(commHandle, commName);
 }
+
+extern HcclResult HcclCommResumeFace(HcclComm comm)
+{
+    typedef HcclResult (*HcclCommResumeFace)(HcclComm comm);
+    static HcclCommResumeFace func = nullptr;
+    if (func == nullptr) {
+        func = (HcclCommResumeFace)GET_FUNC(HcclCommResume);
+    }
+    TORCH_CHECK(func, "Failed to find function HcclCommResume,"
+        " maybe you cann version is too low, please upgrade it", DIST_ERROR(ErrCode::NOT_FOUND));
+    return func(comm);
+}
+
+extern bool isHcclFeatureSupported(HcclCommConfigCapability configParameter)
+{
+    typedef uint32_t(*HcclGetCommConfigCapabilityFunc)();
+    static HcclGetCommConfigCapabilityFunc func = (HcclGetCommConfigCapabilityFunc) GET_FUNC(
+        HcclGetCommConfigCapability);
+    if (func == nullptr) {
+        return false;
+    }
+    return configParameter < func();
+}
 } // namespace native
 } // namespace at_npu
\ No newline at end of file
diff --git a/torch_npu/csrc/core/npu/interface/HcclInterface.h b/torch_npu/csrc/core/npu/interface/HcclInterface.h
index 67a1524f96f5c2afca5ba7102bf893c4567e9ecf..df81cd23d4fdf2e6eb22fdf2c41de985e21ffb52 100644
--- a/torch_npu/csrc/core/npu/interface/HcclInterface.h
+++ b/torch_npu/csrc/core/npu/interface/HcclInterface.h
@@ -16,5 +16,17 @@ namespace hccl {
 
 extern HcclResult HcclGetCommNameFace(HcclComm commHandle, char* commName);
 
+/**
+ * @ingroup AscendCL
+ * @brief Check whether the given HCCL comm config feature is supported.
+ *
+ * @param configParameter [IN] comm config capability to query
+ * @return bool feature support status
+ *
+ * @retval true The feature is supported.
+ * @retval false The feature is not supported.
+ */
+extern bool isHcclFeatureSupported(HcclCommConfigCapability configParameter);
+
 } // namespace native
 } // namespace at_npu
\ No newline at end of file
diff --git a/torch_npu/csrc/distributed/HCCLUtils.cpp b/torch_npu/csrc/distributed/HCCLUtils.cpp
index 9ddc5ecb5ac9d6c709d90abeac4b1ebca34c346b..e80a2cebd049a63e31c6d7bb2725ff20c2a9f294 100644
--- a/torch_npu/csrc/distributed/HCCLUtils.cpp
+++ b/torch_npu/csrc/distributed/HCCLUtils.cpp
@@ -1,4 +1,5 @@
 #include "torch_npu/csrc/distributed/HCCLUtils.hpp"
+#include "torch_npu/csrc/core/npu/interface/HcclInterface.h"
 
 namespace c10d_npu {
 
@@ -85,4 +86,9 @@ std::string getHcclDataTypeSerialString(HcclDataType type)
     }
 }
 
+bool isSupportHcclCommName()
+{
+    return at_npu::hccl::isHcclFeatureSupported(HcclCommConfigCapability::HCCL_COMM_CONFIG_COMM_NAME);
+}
+
 }
diff --git a/torch_npu/csrc/distributed/HCCLUtils.hpp b/torch_npu/csrc/distributed/HCCLUtils.hpp
index 755de3cd6a6b2a5fd774bb896cc9d6a433e7a936..322e40a575fce41eeed429e699659e3f875fd762 100644
--- a/torch_npu/csrc/distributed/HCCLUtils.hpp
+++ b/torch_npu/csrc/distributed/HCCLUtils.hpp
@@ -61,6 +61,8 @@ std::string getHcclDataTypeSerialString(HcclDataType type);
 
 bool isFilePathValid(const std::string& path);
 
+bool isSupportHcclCommName();
+
 // RAII wrapper for HCCL communicator
 class HCCLComm {
 public:
diff --git a/torch_npu/csrc/distributed/HcclCompile.h b/torch_npu/csrc/distributed/HcclCompile.h
index ffe31842a751b8f44e4e7c589878ab0642c02358..3952d678b29385b8790dbd14904f00e1c914a442 100644
--- a/torch_npu/csrc/distributed/HcclCompile.h
+++ b/torch_npu/csrc/distributed/HcclCompile.h
@@ -18,6 +18,7 @@ LOAD_FUNCTION(HcclScatter)
 LOAD_FUNCTION(HcclBatchSendRecv)
 LOAD_FUNCTION(HcclAlltoAll)
 LOAD_FUNCTION(HcclCommInitRootInfoConfig)
+LOAD_FUNCTION(HcclGetCommConfigCapability)
 
 extern HcclResult hcclAlltoAllV(const void *sendBuf, const void *sendCounts, const void *sdispls,
     HcclDataType sendType, const void *recvBuf, const void *recvCounts, const void *rdispls,
@@ -129,4 +130,15 @@ HcclResult hcclCommInitRootInfoConfig(uint32_t nRanks, const HcclRootInfo *rootI
     auto ret = func(nRanks, rootInfo, rank, config, comm);
     return ret;
 }
+
+bool isHcclFeatureSupported(HcclCommConfigCapability configParameter)
+{
+    typedef uint32_t(*HcclGetCommConfigCapabilityFunc)();
+    static HcclGetCommConfigCapabilityFunc func = (HcclGetCommConfigCapabilityFunc) GET_FUNC(
+        HcclGetCommConfigCapability);
+    if (func == nullptr) {
+        return false;
+    }
+    return configParameter < func();
+}
 } // namespace c10d_npu
diff --git a/torch_npu/csrc/distributed/Init.cpp b/torch_npu/csrc/distributed/Init.cpp
index 790034cce7ceef915c5c635bffc45dc6ce180ab7..10ca0f2334a2c0a385dd3789c11c90e3c0cb0934 100644
--- a/torch_npu/csrc/distributed/Init.cpp
+++ b/torch_npu/csrc/distributed/Init.cpp
@@ -273,6 +273,8 @@ PyObject* c10d_npu_init(PyObject* _unused, PyObject* noargs) {
         py::arg("src") = 0,
         py::call_guard<py::gil_scoped_release>());
 
+    module.def("_is_support_hccl_comm_name", &c10d_npu::isSupportHcclCommName);
+
     shared_ptr_class_<::c10d_npu::Reducer>(module, "Reducer")
         .def(py::init<
                 std::vector<at::Tensor>,
@@ -384,7 +386,18 @@ PyObject* c10d_npu_init(PyObject* _unused, PyObject* noargs) {
           py::arg("timeout") = kProcessGroupDefaultTimeout,
           py::call_guard<py::gil_scoped_release>())
       .def("get_hccl_comm", &::c10d_npu::ProcessGroupHCCL::getHcclComm)
-      .def("get_hccl_comm_name", &::c10d_npu::ProcessGroupHCCL::getHcclCommName)
+      .def("_set_hccl_comm_name", &::c10d_npu::ProcessGroupHCCL::setHcclCommName)
+      .def("get_hccl_comm_name",
+          [](::c10d_npu::ProcessGroupHCCL &pg, py::args args, py::kwargs kwargs)
+              -> std::string {
+              int rankid = py::cast<int>(args[0]);
+              bool init_comm = true;
+              if (kwargs.contains("init_comm")) {
+                  init_comm = py::cast<bool>(kwargs["init_comm"]);
+              }
+              return pg.getHcclCommName(rankid, init_comm);
+          },
+          py::call_guard<py::gil_scoped_release>())
       .def("_get_stream_id", &::c10d_npu::ProcessGroupHCCL::getStreamId, py::arg("p2p") = false)
       .def_property_readonly("options", &::c10d_npu::ProcessGroupHCCL::getOptions)
diff --git a/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp b/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp
index 020ecefd22700b96692543db290d6b0607e8e88a..fe61ce35e43a40e78bc9aecdb52b609d758147d0 100644
--- a/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp
+++ b/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp
@@ -29,6 +29,7 @@
 #include "torch_npu/csrc/distributed/HCCLUtils.hpp"
 #include "torch_npu/csrc/distributed/HcclCompile.h"
 #include "torch_npu/csrc/distributed/ProcessGroupHCCL.hpp"
+#include "torch_npu/csrc/toolkit/profiler/common/utils.h"
 #include "torch_npu/csrc/framework/FormatHelper.h"
 #include "torch_npu/csrc/framework/utils/OpPreparation.h"
 
@@ -37,6 +38,7 @@ namespace {
 static constexpr uint32_t kOpWaitTimeoutOffset = 30U; // second
 static uint32_t kOpWaitTimeout = 1868U; // second
 constexpr const char* P2P_DEVICE_KEY = "_p2p";
+constexpr const int32_t HCCL_CONFIG_SIZE = 32;
 using hcclUs = std::chrono::steady_clock::time_point;
 #define DURATION_US(x) (std::chrono::duration_cast<std::chrono::microseconds>(x))
 
@@ -211,6 +213,21 @@ void getP2PHcclCommCofig(HcclCommConfig* config)
 {
     HcclCommConfigInit(config);
     config->hcclBufferSize = c10_npu::option::OptionsManager::GetP2PBufferSize();
+    // Compatible with the size check of older HCCL versions: forcibly set the size field
+    // of the config to 32 bytes so that versions within N ± 2 remain compatible
+    if (!isHcclFeatureSupported(HcclCommConfigCapability::HCCL_COMM_CONFIG_COMM_NAME)) {
+        size_t *configSize = reinterpret_cast<size_t *>(config);
+        *configSize = HCCL_CONFIG_SIZE;
+    }
+}
+
+void checkHcclCommConfigValid(const HcclCommConfig* config)
+{
+    if (strlen(config->hcclCommName) > 0) {
+        TORCH_CHECK(isHcclFeatureSupported(HcclCommConfigCapability::HCCL_COMM_CONFIG_COMM_NAME),
+            "The current version of CANN does not support the hcclCommName:", config->hcclCommName,
+            DIST_ERROR(ErrCode::NOT_SUPPORT));
+    }
 }
 } // namespace
 
@@ -1013,7 +1030,8 @@ void ProcessGroupHCCL::recordComm(std::string filename, std::string opName, cons
 std::vector<std::shared_ptr<HCCLComm>>& ProcessGroupHCCL::getHCCLComm(
     const std::string& devicesKey,
     const std::vector<at::Device>& devices,
-    HcclCommType commType)
+    HcclCommType commType,
+    HcclCommConfig* commConfig)
 {
     // Sanity check
     if (devicesKey.empty()) {
@@ -1055,7 +1073,12 @@ std::vector<std::shared_ptr<HCCLComm>>& ProcessGroupHCCL::getHCCLComm(
         npuGuard.set_index(devices[i].index());
         switch (commType) {
             case HcclCommType::DEFAULT:
-                hcclComms[i] = HCCLComm::create(numRanks, rank, hcclID);
+                if (commConfig != nullptr) {
+                    checkHcclCommConfigValid(commConfig);
+                    hcclComms[i] = HCCLComm::create_config(numRanks, rank, hcclID, commConfig);
+                } else {
+                    hcclComms[i] = HCCLComm::create(numRanks, rank, hcclID);
+                }
                 break;
             case HcclCommType::P2P:
                 HcclCommConfig config;
@@ -1313,7 +1336,26 @@ int64_t ProcessGroupHCCL::getHcclComm(int rankid)
     return hccl_comm;
 }
 
-std::string ProcessGroupHCCL::getHcclCommName(int rankid) {
+void ProcessGroupHCCL::setHcclCommName(const std::string& hccl_comm_name)
+{
+    auto nameSize = hccl_comm_name.size();
+    TORCH_CHECK(nameSize > 0 && nameSize < COMM_NAME_MAX_LENGTH,
+        "The length of the name must be between 1 and ", COMM_NAME_MAX_LENGTH - 1, ", Invalid hcclCommName:",
+        hccl_comm_name, DIST_ERROR(ErrCode::VALUE));
+    c10::DeviceIndex indexFromCurDevice = c10_npu::current_device();
+    at::Device device = at::Device(c10::DeviceType::PrivateUse1, indexFromCurDevice);
+    std::vector<at::Device> devices = {device};
+    const auto key = getKeyFromDevices(devices);
+    std::lock_guard<std::mutex> lock(mutex_);
+    auto hcclCommNameIter = devHCCLCommNameMap_.emplace(key, hccl_comm_name);
+    auto currentHcclCommName = hcclCommNameIter.first->second;
+    TORCH_CHECK(currentHcclCommName == hccl_comm_name,
+        "The current ProcessGroup has already set the name and cannot be duplicated, Invalid hcclCommName:",
+        hccl_comm_name, ", current hcclCommName:", currentHcclCommName, DIST_ERROR(ErrCode::VALUE));
+}
+
+std::string ProcessGroupHCCL::getHcclCommName(int rankid, bool init_comm)
+{
     TORCH_CHECK(rankid >= 0, "Invalid rank ", rankid, DIST_ERROR(ErrCode::VALUE));
     auto numNPUs = c10_npu::device_count();
     TORCH_CHECK(numNPUs > 0, "Invalid device number", numNPUs, DIST_ERROR(ErrCode::VALUE));
@@ -1327,11 +1369,30 @@ std::string ProcessGroupHCCL::getHcclCommName(int rankid) {
         "If it's incorrect, it might have introduced an error.";
         TORCH_WARN_ONCE(warning_message);
     }
-
     at::Device device = at::Device(c10::DeviceType::PrivateUse1, indexFromCurDevice);
     std::vector<at::Device> devices = {device};
     const auto key = getKeyFromDevices(devices);
-    auto& hcclComms = getHCCLComm(key, devices);
+    if (!init_comm) {
+        std::lock_guard<std::mutex> lock(mutex_);
+        if (devHCCLCommMap_.find(key) == devHCCLCommMap_.end()) {
+            return "";
+        }
+    }
+    std::string hcclCommName = "";
+    std::vector<std::shared_ptr<HCCLComm>> hcclComms;
+    {
+        std::lock_guard<std::mutex> lock(mutex_);
+        hcclCommName = devHCCLCommNameMap_[key];
+    }
+    if (!hcclCommName.empty()) {
+        HcclCommConfig config;
+        HcclCommConfigInit(&config);
+        torch_npu::toolkit::profiler::Utils::safe_strcpy_s(config.hcclCommName, hcclCommName.c_str(),
+            COMM_NAME_MAX_LENGTH);
+        hcclComms = getHCCLComm(key, devices, HcclCommType::DEFAULT, &config);
+    } else {
+        hcclComms = getHCCLComm(key, devices);
+    }
     TORCH_CHECK(hcclComms.size() == 1,
         "expect hcclComms.size() = 1, but hcclComms.size() = ", hcclComms.size(), DIST_ERROR(ErrCode::VALUE));
     HcclComm hcom = hcclComms[0]->getHcclComm();
@@ -1435,7 +1496,7 @@ c10::intrusive_ptr<c10d::Work> ProcessGroupHCCL::collective(
         } catch (std::exception& e) {
             throw std::runtime_error("Open shared directory failed. Please check whether perfdumppath is valid." +
                 DIST_ERROR(ErrCode::NOT_FOUND));
         }
-        
+
         const std::vector& ranks = groupRanks();
         outfile << "[GLOBAL RANKID]:" << ranks[rank_] << "\n";
diff --git a/torch_npu/csrc/distributed/ProcessGroupHCCL.hpp b/torch_npu/csrc/distributed/ProcessGroupHCCL.hpp
index 68194d9e0279014745c9406e819d70b39f70e4bf..66867e07f34bb5bf7f75a361618d6c850fe5314f 100644
--- a/torch_npu/csrc/distributed/ProcessGroupHCCL.hpp
+++ b/torch_npu/csrc/distributed/ProcessGroupHCCL.hpp
@@ -401,7 +401,9 @@ public:
 
     int64_t getHcclComm(int rankid);
 
-    std::string getHcclCommName(int rankid);
+    void setHcclCommName(const std::string& hccl_comm_name);
+
+    std::string getHcclCommName(int rankid, bool init_comm = true);
 
     // Provides an API to abort the ProcessGroup (similar to hcclCommAbort)
     // instead of relying on ProcessGroupHCCL destructor.
@@ -422,7 +424,8 @@ protected:
     std::vector<std::shared_ptr<HCCLComm>>& getHCCLComm(
         const std::string& devicesKey,
         const std::vector<at::Device>& devices,
-        HcclCommType commType = HcclCommType::DEFAULT);
+        HcclCommType commType = HcclCommType::DEFAULT,
+        HcclCommConfig* commConfig = nullptr);
 
     // Get the data vol for HCCL operators.
     void recordDataVol(std::string opName, const std::string dataVol, const int currRank,
@@ -441,6 +444,12 @@ protected:
         int rank,
         c10d::OpType opType);
 
+    // Do not call this directly, use ProcessGroup::setGroupName instead.
+    void setGroupName(const std::string& name)
+    {
+        pg_name_ = name;
+    }
+
     static const int64_t kWatchdogThreadSleepMillis;
 
     // The store is used to broadcast the HCCL Master ID of rank 0.
@@ -482,6 +491,8 @@ protected:
     // Note that the order of the device for the tensor list matters.
     std::unordered_map<std::string, std::vector<std::shared_ptr<HCCLComm>>> devHCCLCommMap_;
 
+    std::unordered_map<std::string, std::string> devHCCLCommNameMap_;
+
     // Mutex to guard maps like devHCCLCommMap_.
     std::mutex mutex_;
 
@@ -581,6 +592,7 @@ protected:
     // Counting for the sequential number of HCCL collective call.
     uint64_t seq_{0};
 
+    std::string pg_name_;
 
     std::exception_ptr watchDogException_ = nullptr;
 
diff --git a/torch_npu/distributed/__init__.py b/torch_npu/distributed/__init__.py
index dfa6cd9efa4cc59295ed78826110e6684eda772c..617af7e515006a8f2dc3d58c7817d4948b32ab95 100644
--- a/torch_npu/distributed/__init__.py
+++ b/torch_npu/distributed/__init__.py
@@ -24,6 +24,8 @@ if is_available() and not torch_npu._C._c10d_npu_init():
 
 from torch_npu._C._distributed_c10d import (
     _verify_params_across_processes,
+    _is_support_hccl_comm_name,
 )
+
 from .distributed_c10d import batch_isend_irecv, gather, gather_object, is_hccl_available
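
Usage sketch (supplementary, not part of the patch): based on the test cases above, the new bindings would be exercised roughly as follows on an already-initialized HCCL process group; the communicator name "demo_comm" and the single-node rank-to-device mapping are assumptions for illustration only.

    import torch
    import torch.distributed as dist
    import torch_npu

    # assumes dist.init_process_group(backend='hccl', ...) has already run
    # and torch_npu.npu.set_device(rank) was called, as in the tests above
    rank = dist.get_rank()
    pg = torch.distributed.new_group()
    backend = pg._get_backend(torch.device('npu'))

    if torch_npu.distributed._is_support_hccl_comm_name():
        # optional: name the communicator before it is created (1-127 characters)
        backend._set_hccl_comm_name("demo_comm")

    # returns "" when the communicator does not exist yet and init_comm=False
    lazy_name = backend.get_hccl_comm_name(rank, init_comm=False)
    # creates the communicator (with the configured name, if any) and returns its name
    name = backend.get_hccl_comm_name(rank, init_comm=True)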