From 25c73733430196b2465efb3aa17bace1a0deeab4 Mon Sep 17 00:00:00 2001 From: yu-liang-bin Date: Wed, 16 Jul 2025 15:41:51 +0800 Subject: [PATCH] fix memory bug --- test/npu/test_npu_format.py | 49 ++++++++++ .../.pytorch-disabled-tests.json | 3 +- third_party/hccl/inc/hccl/hccl.h | 2 + third_party/hccl/inc/hccl/hccl_types.h | 4 +- third_party/op-plugin | 2 +- third_party/torchair/torchair | 2 +- torch_npu/__init__.py | 2 + .../csrc/core/npu/NPUCachingAllocator.cpp | 6 +- torch_npu/csrc/core/npu/NPUFunctions.cpp | 8 +- torch_npu/csrc/core/npu/NPUStream.cpp | 2 + .../csrc/core/npu/NPUWorkspaceAllocator.cpp | 10 +- .../csrc/core/npu/interface/AclInterface.cpp | 25 +++++ .../csrc/core/npu/interface/AclInterface.h | 4 + .../csrc/core/npu/register/OptionsManager.cpp | 6 +- .../csrc/distributed/ProcessGroupHCCL.cpp | 93 +++++++++++++++++++ .../csrc/distributed/ProcessGroupHCCL.hpp | 10 +- torch_npu/csrc/ipc/StorageSharing.cpp | 8 ++ torch_npu/csrc/npu/Module.cpp | 10 ++ torch_npu/csrc/profiler/npu_profiler.cpp | 5 +- torch_npu/csrc/profiler/npu_profiler.h | 15 ++- torch_npu/npu/_format.py | 38 ++++++++ 21 files changed, 279 insertions(+), 25 deletions(-) create mode 100644 test/npu/test_npu_format.py create mode 100644 torch_npu/npu/_format.py diff --git a/test/npu/test_npu_format.py b/test/npu/test_npu_format.py new file mode 100644 index 0000000000..2bc1c067ff --- /dev/null +++ b/test/npu/test_npu_format.py @@ -0,0 +1,49 @@ +import torch +import torch_npu +from torch_npu.testing.testcase import TestCase, run_tests + + +class TestNPUFormat(TestCase): + + def test_enum_values(self): + """test the enumeration value""" + self.assertEqual(torch_npu.Format.NCHW.value, 0) + self.assertEqual(torch_npu.Format.NHWC.value, 1) + + def test_npu_format_cast(self): + """test npu_format_cast""" + tensor = torch.ones(2, 2).npu() + + out1 = torch_npu.npu_format_cast(tensor, 0) + fmt1 = torch_npu.get_npu_format(out1) + self.assertEqual(fmt1, torch_npu.Format.NCHW) + + out2 = torch_npu.npu_format_cast(tensor, torch_npu.Format.NHWC) + fmt2 = torch_npu.get_npu_format(out2) + self.assertEqual(fmt2, torch_npu.Format.NHWC) + + def test_npu_format_cast_(self): + """test npu_format_cast_""" + x1 = torch.ones(2, 2).npu() + x2 = torch.ones(2, 2).npu() + + torch_npu.npu_format_cast_(x1, 0) + fmt1 = torch_npu.get_npu_format(x1) + self.assertEqual(fmt1, torch_npu.Format.NCHW) + + torch_npu.npu_format_cast_(x2, torch_npu.Format.NHWC) + fmt2 = torch_npu.get_npu_format(x2) + self.assertEqual(fmt2, torch_npu.Format.NHWC) + + def test_get_npu_format(self): + """test get_npu_format""" + x1 = torch.ones(2, 2).npu() + torch_npu.npu_format_cast_(x1, 0) + + fmt1 = torch_npu.get_npu_format(x1) + self.assertEqual(fmt1, torch_npu.Format.NCHW) + self.assertEqual(fmt1, 0) + + +if __name__ == "__main__": + run_tests() diff --git a/test/unsupported_test_cases/.pytorch-disabled-tests.json b/test/unsupported_test_cases/.pytorch-disabled-tests.json index b139132832..5872036a38 100644 --- a/test/unsupported_test_cases/.pytorch-disabled-tests.json +++ b/test/unsupported_test_cases/.pytorch-disabled-tests.json @@ -31601,5 +31601,6 @@ "test_fake_autocast_special_bessel_j1_npu_float32 (__main__.TestFakeTensorPRIVATEUSE1)": ["", [""]], "test_fake_autocast_nn_functional_multi_head_attention_forward_npu_float32 (__main__.TestFakeTensorPRIVATEUSE1)": ["", [""]], "test_fake_autocast_norm_nuc_npu_float32 (__main__.TestFakeTensorPRIVATEUSE1)": ["", [""]], - "test_fake_autocast_tanh_npu_float32 (__main__.TestFakeTensorPRIVATEUSE1)": ["", [""]] + 
"test_fake_autocast_tanh_npu_float32 (__main__.TestFakeTensorPRIVATEUSE1)": ["", [""]], + "test_unwaited (__main__.TestWithHccl)": ["", [""]] } diff --git a/third_party/hccl/inc/hccl/hccl.h b/third_party/hccl/inc/hccl/hccl.h index 023914a348..216ef7a838 100644 --- a/third_party/hccl/inc/hccl/hccl.h +++ b/third_party/hccl/inc/hccl/hccl.h @@ -212,6 +212,8 @@ inline void HcclCommConfigInit(HcclCommConfig *config) config->hcclRdmaTrafficClass = HCCL_COMM_TRAFFIC_CLASS_CONFIG_NOT_SET; config->hcclRdmaServiceLevel = HCCL_COMM_SERVICE_LEVEL_CONFIG_NOT_SET; config->hcclOpExpansionMode = HCCL_COMM_DEFAULT_OP_EXPANSION_MODE; + config->hcclWorldRankID = 0; + config->hcclJobID = 0; } /** diff --git a/third_party/hccl/inc/hccl/hccl_types.h b/third_party/hccl/inc/hccl/hccl_types.h index 40631676c1..9a02c61c04 100644 --- a/third_party/hccl/inc/hccl/hccl_types.h +++ b/third_party/hccl/inc/hccl/hccl_types.h @@ -15,7 +15,7 @@ extern "C" { const uint32_t HCCL_COMM_CONFIG_INFO_BYTES = 24; const uint32_t HCCL_COMM_CONFIG_MAGIC_WORD = 0xf0f0f0f0; -const uint32_t HCCL_COMM_CONFIG_VERSION = 5; +const uint32_t HCCL_COMM_CONFIG_VERSION = 6; const uint32_t HCCL_COMM_DEFAULT_BUFFSIZE = 200; // 200MB buffer size const uint32_t HCCL_COMM_DEFAULT_DETERMINISTIC = 0; // Disable deterministic calculations const uint32_t COMM_NAME_MAX_LENGTH = 128; @@ -132,6 +132,8 @@ typedef struct HcclCommConfigDef { uint32_t hcclOpExpansionMode; uint32_t hcclRdmaTrafficClass; uint32_t hcclRdmaServiceLevel; + uint32_t hcclWorldRankID; + uint64_t hcclJobID; } HcclCommConfig; typedef enum { diff --git a/third_party/op-plugin b/third_party/op-plugin index f8fab40561..8407b7cbb0 160000 --- a/third_party/op-plugin +++ b/third_party/op-plugin @@ -1 +1 @@ -Subproject commit f8fab40561b64047e20d2a98c7eac6f100cc71b6 +Subproject commit 8407b7cbb0c7046f80d006987170db775d637cc5 diff --git a/third_party/torchair/torchair b/third_party/torchair/torchair index ec5747ba54..e4bf05da76 160000 --- a/third_party/torchair/torchair +++ b/third_party/torchair/torchair @@ -1 +1 @@ -Subproject commit ec5747ba5477a4508131ca4401088e7383908266 +Subproject commit e4bf05da768a6d9a98e67f86c269e41a2369d02b diff --git a/torch_npu/__init__.py b/torch_npu/__init__.py index ce4b7cc6ad..9e7cb7b250 100644 --- a/torch_npu/__init__.py +++ b/torch_npu/__init__.py @@ -75,6 +75,7 @@ from torch_npu.utils import _apply_module_patch, _add_tensor_methods, _add_colle _apply_npu_show_warning, _apply_npugraph_tree_methods from torch_npu.utils._dynamo_device import _dynamo_register_interface_for_device from torch_npu.npu._stream_check import apply_sanitizer_patch +from torch_npu.npu._format import _apply_npu_format_patch import torch_npu.utils.custom_ops import torch_npu.distributed.rpc import torch_npu.op_plugin @@ -177,6 +178,7 @@ def _apply_class_patches(): _apply_fsdp_patch() _apply_npugraph_tree_methods() _add_reductions_methods() + _apply_npu_format_patch() def _apply_distributed_methods_patch(): diff --git a/torch_npu/csrc/core/npu/NPUCachingAllocator.cpp b/torch_npu/csrc/core/npu/NPUCachingAllocator.cpp index e3c3a327be..b81750695f 100644 --- a/torch_npu/csrc/core/npu/NPUCachingAllocator.cpp +++ b/torch_npu/csrc/core/npu/NPUCachingAllocator.cpp @@ -1554,7 +1554,7 @@ public: stats.allocated_bytes[static_cast(StatType::AGGREGATE)].current, stats.reserved_bytes[static_cast(StatType::AGGREGATE)].current, stats.active_bytes[static_cast(StatType::AGGREGATE)].current, - reinterpret_cast(block->stream) }); + block->stream }); #endif return block; @@ -1619,7 +1619,7 @@ public: 
stats.allocated_bytes[static_cast(StatType::AGGREGATE)].current, stats.reserved_bytes[static_cast(StatType::AGGREGATE)].current, stats.active_bytes[static_cast(StatType::AGGREGATE)].current, - reinterpret_cast(block->stream) }); + block->stream }); #endif } @@ -2434,7 +2434,7 @@ private: stats.allocated_bytes[static_cast(StatType::AGGREGATE)].current, stats.reserved_bytes[static_cast(StatType::AGGREGATE)].current, stats.active_bytes[static_cast(StatType::AGGREGATE)].current, - reinterpret_cast(block->stream) }); + block->stream }); #endif } diff --git a/torch_npu/csrc/core/npu/NPUFunctions.cpp b/torch_npu/csrc/core/npu/NPUFunctions.cpp index 6de379e039..404a7ffe07 100644 --- a/torch_npu/csrc/core/npu/NPUFunctions.cpp +++ b/torch_npu/csrc/core/npu/NPUFunctions.cpp @@ -46,7 +46,6 @@ aclError GetDevice(int32_t *device) { if (targetDeviceIndex >= 0) { *device = targetDeviceIndex; - NPU_CHECK_ERROR_WITHOUT_UCE(SetDevice(targetDeviceIndex)); return ACL_ERROR_NONE; } @@ -60,13 +59,8 @@ aclError GetDevice(int32_t *device) } if (err == ACL_ERROR_NONE) { local_device = *device; - } else if (err == ACL_ERROR_RT_CONTEXT_NULL && aclrtSetDevice(0) == ACL_ERROR_NONE) { + } else if (err == ACL_ERROR_RT_CONTEXT_NULL) { *device = 0; - local_device = 0; - std::lock_guard lock(mtx); - if (used_devices.find(local_device) == used_devices.end()) { - NPU_CHECK_ERROR_WITHOUT_UCE(aclrtGetCurrentContext(&used_devices[local_device])); - } return ACL_ERROR_NONE; } return err; diff --git a/torch_npu/csrc/core/npu/NPUStream.cpp b/torch_npu/csrc/core/npu/NPUStream.cpp index 4411760ab4..cc8a53c54d 100644 --- a/torch_npu/csrc/core/npu/NPUStream.cpp +++ b/torch_npu/csrc/core/npu/NPUStream.cpp @@ -229,6 +229,8 @@ static void initNPUStreamsOnce() { // Inits default and secondary streams (once, globally) c10::DeviceIndex device_index = current_device(); + // makesure on real devcie + SetTargetDevice(); if (!initialize_flag[device_index]) { std::lock_guard lock(mtx[device_index]); if (!initialize_flag[device_index]) { diff --git a/torch_npu/csrc/core/npu/NPUWorkspaceAllocator.cpp b/torch_npu/csrc/core/npu/NPUWorkspaceAllocator.cpp index 7d5173dec8..660089b0fb 100644 --- a/torch_npu/csrc/core/npu/NPUWorkspaceAllocator.cpp +++ b/torch_npu/csrc/core/npu/NPUWorkspaceAllocator.cpp @@ -113,7 +113,7 @@ public: stats.allocated_bytes.current, stats.reserved_bytes.current, stats.allocated_bytes.current, - reinterpret_cast(stream)} + stream } ); #endif block->data_ptr = nullptr; @@ -154,7 +154,7 @@ public: stats.allocated_bytes.current, stats.reserved_bytes.current, stats.allocated_bytes.current, - reinterpret_cast(stream)} + stream } ); this->last_block = block; this->last_stream = stream; @@ -180,7 +180,7 @@ public: stats.allocated_bytes.current, stats.reserved_bytes.current, stats.allocated_bytes.current, - reinterpret_cast(stream)} + stream } ); this->last_block = block; this->last_stream = stream; @@ -204,7 +204,7 @@ public: stats.allocated_bytes.current, stats.reserved_bytes.current, stats.allocated_bytes.current, - reinterpret_cast(this->last_stream)} + this->last_stream } ); } #endif @@ -254,7 +254,7 @@ public: stats.allocated_bytes.current, stats.reserved_bytes.current, stats.allocated_bytes.current, - reinterpret_cast(block_pair.first)} + block_pair.first } ); #endif } diff --git a/torch_npu/csrc/core/npu/interface/AclInterface.cpp b/torch_npu/csrc/core/npu/interface/AclInterface.cpp index 779040d4c7..000ff651ae 100644 --- a/torch_npu/csrc/core/npu/interface/AclInterface.cpp +++ 
b/torch_npu/csrc/core/npu/interface/AclInterface.cpp @@ -89,6 +89,8 @@ LOAD_FUNCTION(aclrtIpcMemClose) LOAD_FUNCTION(aclrtMemExportToShareableHandle) LOAD_FUNCTION(aclrtMemSetPidToShareableHandle) LOAD_FUNCTION(aclrtMemImportFromShareableHandle) +LOAD_FUNCTION(aclrtDeviceGetBareTgid) +LOAD_FUNCTION(aclrtStreamGetId) aclprofStepInfoPtr init_stepinfo() { typedef aclprofStepInfoPtr(*npdInitFunc)(); @@ -1020,5 +1022,28 @@ aclError AclrtMemImportFromShareableHandle(uint64_t shareableHandle, int32_t dev return func(shareableHandle, deviceId, handle); } +aclError AclrtDeviceGetBareTgid(int32_t *pid) +{ + typedef aclError (*AclrtDeviceGetBareTgid)(int32_t *); + static AclrtDeviceGetBareTgid func = nullptr; + if (func == nullptr) { + func = (AclrtDeviceGetBareTgid) GET_FUNC(aclrtDeviceGetBareTgid); + } + + TORCH_CHECK(func, "Failed to find function aclrtDeviceGetBareTgid", PTA_ERROR(ErrCode::NOT_FOUND)); + return func(pid); +} + +aclError AclrtStreamGetId(aclrtStream stream, int32_t* stream_id) +{ + typedef aclError(*AclrtStreamGetIdFunc)(aclrtStream, int32_t*); + static AclrtStreamGetIdFunc func = nullptr; + if (func == nullptr) { + func = (AclrtStreamGetIdFunc)GET_FUNC(aclrtStreamGetId); + } + TORCH_CHECK(func, "Failed to find function ", "AclrtStreamGetId", PROF_ERROR(ErrCode::NOT_FOUND)); + return func(stream, stream_id); +} + } // namespace acl } // namespace c10 diff --git a/torch_npu/csrc/core/npu/interface/AclInterface.h b/torch_npu/csrc/core/npu/interface/AclInterface.h index f2c991b19f..e04159afca 100644 --- a/torch_npu/csrc/core/npu/interface/AclInterface.h +++ b/torch_npu/csrc/core/npu/interface/AclInterface.h @@ -243,5 +243,9 @@ aclError AclrtMemSetPidToShareableHandle(uint64_t shareableHandle, int32_t *pid, aclError AclrtMemImportFromShareableHandle(uint64_t shareableHandle, int32_t deviceId, aclrtDrvMemHandle *handle); +aclError AclrtDeviceGetBareTgid(int32_t *pid); + +aclError AclrtStreamGetId(aclrtStream stream, int32_t* stream_id); + } // namespace acl } // namespace c10_npu diff --git a/torch_npu/csrc/core/npu/register/OptionsManager.cpp b/torch_npu/csrc/core/npu/register/OptionsManager.cpp index 228ba340da..5851665c64 100644 --- a/torch_npu/csrc/core/npu/register/OptionsManager.cpp +++ b/torch_npu/csrc/core/npu/register/OptionsManager.cpp @@ -482,11 +482,11 @@ uint32_t OptionsManager::GetAclOpInitMode() const static uint32_t acl_op_init_mode = []() -> uint32_t { char* buf_val = std::getenv("ACL_OP_INIT_MODE"); // Default 0 - int64_t acl_op_init_mode = (buf_val != nullptr) ? strtol(buf_val, nullptr, 10) : 1; + int64_t acl_op_init_mode = (buf_val != nullptr) ? 
strtol(buf_val, nullptr, 10) : 0; std::unordered_map aclOpInitMode = getAclOpInitMode(); if (aclOpInitMode.find(acl_op_init_mode) == aclOpInitMode.end()) { - acl_op_init_mode = 1; - TORCH_NPU_WARN_ONCE("Get env ACL_OP_INIT_MODE not in [0, 1, 2], so reset it to the default value 1."); + acl_op_init_mode = 0; + TORCH_NPU_WARN_ONCE("Get env ACL_OP_INIT_MODE not in [0, 1, 2], so reset it to the default value 0."); } return static_cast(acl_op_init_mode); }(); diff --git a/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp b/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp index ce1d6e7c7f..244f6ef211 100644 --- a/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp +++ b/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp @@ -19,8 +19,12 @@ #include #include #include +#include +#include #include +#include + #include "op_plugin/OpInterface.h" #include "third_party/acl/inc/acl/acl.h" #include "third_party/acl/inc/acl/acl_base.h" @@ -63,6 +67,7 @@ constexpr const char* P2P_DEVICE_KEY = "_p2p"; using hcclUs = std::chrono::steady_clock::time_point; constexpr int32_t MAX_GROUP_NAME_LEN = 128; +constexpr int32_t NSLB_JOBID_OFFSET = 32; // HCCL ReduceOp mapping std::map hcclOp = { @@ -950,6 +955,24 @@ ProcessGroupHCCL::ProcessGroupHCCL( c10d::PrefixStore *prefixStore = dynamic_cast(store_.get()); globalStore_ = prefixStore ? prefixStore->getUnderlyingNonPrefixStore() : store_; + c10::intrusive_ptr getTcpStore = store_; + while (getTcpStore) { + c10d::PrefixStore *asPrefixStore = dynamic_cast(getTcpStore.get()); + c10d::TCPStore *tcpStore = dynamic_cast(getTcpStore.get()); + if (tcpStore) { + if (!(tcpStore->getHost().empty())) { + tcpMasterAddr = tcpStore->getHost(); + tcpMasterPort = tcpStore->getPort(); + break; + } + } + if (asPrefixStore) { + getTcpStore = asPrefixStore->getUnderlyingStore(); + } else { + break; + } + } + try { if (blockingWait != nullptr) { auto val = std::stoi(blockingWait); @@ -1181,6 +1204,7 @@ void ProcessGroupHCCL::abortAndClearHcclComm(c10::optional abortRea abortCommsFromMap(devHCCLCommMap_, rank_, abortReason); devHCCLCommMap_.clear(); devHCCLCommNameMap_.clear(); + p2pSendRecvKeys_.clear(); hcclCommCounter_ = 0; return; } @@ -1223,6 +1247,7 @@ ProcessGroupHCCL::~ProcessGroupHCCL() } } devHCCLCommMap_.clear(); + p2pSendRecvKeys_.clear(); } ASCEND_LOGI("process group destroyed, group id is %s.", options_->group_id.c_str()); logger->info("process group destroyed, group id is %s.", options_->group_id.c_str()); @@ -2155,6 +2180,30 @@ std::vector>& ProcessGroupHCCL::getHCCLComm( return createHCCLComm(devicesKey, devices, commType, commConfig, p2pRank); } +void ProcessGroupHCCL::setNSLBCommConfig(HcclCommConfig** commConfig) +{ + const char* envPtr = std::getenv("RANK"); + if (envPtr == nullptr) { + ASCEND_LOGI("Failed to get env info for NSLB-DP."); + return; + } + uint32_t worldRankID = std::stoi(std::string(envPtr)); + options_->hccl_config["hccl_world_rank_id"] = worldRankID; + uint32_t masterPort = tcpMasterPort; + struct sockaddr_in sa; + std::string master_addr = tcpMasterAddr; + inet_pton(AF_INET, std::string(master_addr).c_str(), &(sa.sin_addr)); + uint32_t masterIp = ntohl(sa.sin_addr.s_addr); + uint64_t jobID = masterPort; + jobID = (jobID << NSLB_JOBID_OFFSET); + jobID += masterIp; + options_->hccl_config["hccl_job_id"] = jobID; + if ((*commConfig) != nullptr) { + (*commConfig)->hcclWorldRankID = worldRankID; + (*commConfig)->hcclJobID = jobID; + } +} + void ProcessGroupHCCL::createHCCLComm( const std::string& devicesKey, const std::vector& devices, @@ -2179,6 +2228,10 @@ void 
ProcessGroupHCCL::createHCCLComm( HcclCommConfig config; + if (options_->global_ranks_in_group.empty()) { + setNSLBCommConfig(&commConfig); + } + npuGuard.set_index(devices[i].index()); switch (commType) { case HcclCommType::DEFAULT: @@ -2309,6 +2362,9 @@ bool ProcessGroupHCCL::createHCCLCommEx( return false; } hcclComms[i] = subComm; + if (commType == HcclCommType::P2P) { + hcclComms[i]->p2pPeer = getP2pPeer(); + } // Creates the HCCL streams streamVal.push_back(getNPUStreamByCurrentType(devices[i].index())); } @@ -2411,6 +2467,14 @@ std::vector>& ProcessGroupHCCL::createHCCLComm( // Move the HCCL resource to cache devHCCLCommMap_.emplace(devicesKey, std::move(hcclComms)); + if (commType == HcclCommType::P2P) { + auto iter = p2pSendRecvKeys_.find(rank_); + if (iter == p2pSendRecvKeys_.end()) { + p2pSendRecvKeys_.emplace(rank_, std::vector{devicesKey}); + } else { + iter->second.push_back(devicesKey); + } + } return devHCCLCommMap_[devicesKey]; } @@ -2767,6 +2831,19 @@ void ProcessGroupHCCL::resumeHcclComm(int device_id) HCCL_CHECK_ERROR(at_npu::hccl::HcclCommResumeFace(comm)); } } + if (p2pSendRecvKeys_.find(rank_) != p2pSendRecvKeys_.end()) { + auto p2pKeys = p2pSendRecvKeys_[rank_]; + for (const auto& p2pKey : p2pKeys) { + if (devHCCLCommMap_.find(p2pKey) != devHCCLCommMap_.end()) { + // Reuse the cached communicator if there is one. + auto& hcclComms = devHCCLCommMap_[p2pKey]; + for (const auto& hcclComm : hcclComms) { + auto comm = hcclComm->getHcclComm(); + HCCL_CHECK_ERROR(at_npu::hccl::HcclCommResumeFace(comm)); + } + } + } + } } ASCEND_LOGI("resumeHcclComm success, group id is %s.", options_->group_id.c_str()); } @@ -3107,6 +3184,22 @@ HcclCommConfig ProcessGroupHCCL::createHcclCommConfigWithOptions() } } + if (options_->hccl_config.find("hccl_world_rank_id") != options_->hccl_config.end()) { + if (std::holds_alternative(options_->hccl_config["hccl_world_rank_id"])) { + config.hcclOpExpansionMode = std::get(options_->hccl_config["hccl_world_rank_id"]); + } else { + TORCH_CHECK(false, "Value type of hccl_world_rank_id should be int.", DIST_ERROR(ErrCode::TYPE)); + } + } + + if (options_->hccl_config.find("hccl_job_id") != options_->hccl_config.end()) { + if (std::holds_alternative(options_->hccl_config["hccl_job_id"])) { + config.hcclOpExpansionMode = std::get(options_->hccl_config["hccl_job_id"]); + } else { + TORCH_CHECK(false, "Value type of hccl_job_id should be int.", DIST_ERROR(ErrCode::TYPE)); + } + } + return config; } diff --git a/torch_npu/csrc/distributed/ProcessGroupHCCL.hpp b/torch_npu/csrc/distributed/ProcessGroupHCCL.hpp index 9c2f365b3e..057afe5ccb 100644 --- a/torch_npu/csrc/distributed/ProcessGroupHCCL.hpp +++ b/torch_npu/csrc/distributed/ProcessGroupHCCL.hpp @@ -384,7 +384,7 @@ public: return c10::make_intrusive(_is_high_priority_stream); } - std::unordered_map> hccl_config; + std::unordered_map> hccl_config; std::chrono::milliseconds opTimeout; // Schedule HCCL operations on high priority CUDA streams @@ -571,6 +571,8 @@ public: void resumeHcclComm(int device_id); + void setNSLBCommConfig(HcclCommConfig** commConfig); + bool setCommWorkingDevNic( const HcclComm& comm, int nranks, @@ -746,6 +748,8 @@ protected: // // Note that the order of the device for the tensor list matters. 
std::unordered_map>> devHCCLCommMap_; + + std::unordered_map> p2pSendRecvKeys_; std::unordered_map devHCCLCommNameMap_; @@ -960,6 +964,10 @@ protected: std::string pg_desc_; + std::string tcpMasterAddr; + + uint32_t tcpMasterPort; + private: // Helper that encapsulates work shared across all collective communication // primitives. diff --git a/torch_npu/csrc/ipc/StorageSharing.cpp b/torch_npu/csrc/ipc/StorageSharing.cpp index 1169cbd1c5..18fdd4c5e0 100644 --- a/torch_npu/csrc/ipc/StorageSharing.cpp +++ b/torch_npu/csrc/ipc/StorageSharing.cpp @@ -14,6 +14,8 @@ #include "torch_npu/csrc/core/NPUBridge.h" #include "torch_npu/csrc/core/npu/NPUCachingAllocator.h" #include "torch_npu/csrc/core/npu/NPUGuard.h" +#include "torch_npu/csrc/core/NPUStorageImpl.h" +#include "torch_npu/csrc/framework/FormatHelper.h" #include "torch_npu/csrc/ipc/NPUIPCTypes.h" #include "torch_npu/csrc/ipc/StorageSharing.h" @@ -33,6 +35,12 @@ static PyObject* THNPStorage_shareNpu(PyObject* self, PyObject* args) "_share_npu_: only available on NPU.", PTA_ERROR(ErrCode::PARAM)); c10::StorageImpl* storage_impl = storage.unsafeGetStorageImpl(); + auto npu_storage_impl = static_cast(storage.unsafeGetStorageImpl()); + auto format = npu_storage_impl->npu_desc_.npu_format_; + TORCH_CHECK(at_npu::native::FormatHelper::IsBaseFormatType(format), + "Try to share a storage without base format", + PTA_ERROR(ErrCode::TYPE)); + if (storage_impl->received_cuda()) { AT_ERROR( "Supported to send NPU tensor received from another process; other is not currently supported. Consider cloning before sending."); diff --git a/torch_npu/csrc/npu/Module.cpp b/torch_npu/csrc/npu/Module.cpp index b168963fa0..d335acc4e3 100644 --- a/torch_npu/csrc/npu/Module.cpp +++ b/torch_npu/csrc/npu/Module.cpp @@ -1702,6 +1702,15 @@ static PyObject* THNPModule_add_ipc_pid(PyObject* self, PyObject *args) END_HANDLE_TH_ERRORS } +static PyObject* THNPModule_get_ipc_pid(PyObject* self, PyObject *noargs) +{ + HANDLE_TH_ERRORS + int32_t pid; + NPU_CHECK_ERROR(c10_npu::acl::AclrtDeviceGetBareTgid(&pid)); + return THPUtils_packInt32(pid); + END_HANDLE_TH_ERRORS +} + static PyObject* THNPModule_add_p2p_access(PyObject* self, PyObject *args) { HANDLE_TH_ERRORS @@ -1779,6 +1788,7 @@ static struct PyMethodDef THNPModule_methods[] = { {"_get_cann_version", (PyCFunction)THNPModule_get_cann_version, METH_O, nullptr}, {"_is_gte_cann_version", (PyCFunction)THNPModule_is_gte_cann_version, METH_VARARGS, nullptr}, {"_add_ipc_pid", (PyCFunction)THNPModule_add_ipc_pid, METH_VARARGS, nullptr}, + {"_get_ipc_pid", (PyCFunction)THNPModule_get_ipc_pid, METH_NOARGS, nullptr}, {"_add_p2p_access", (PyCFunction)THNPModule_add_p2p_access, METH_VARARGS, nullptr}, {nullptr}}; diff --git a/torch_npu/csrc/profiler/npu_profiler.cpp b/torch_npu/csrc/profiler/npu_profiler.cpp index 295eda9aea..3678da0755 100644 --- a/torch_npu/csrc/profiler/npu_profiler.cpp +++ b/torch_npu/csrc/profiler/npu_profiler.cpp @@ -6,6 +6,7 @@ #include "torch_npu/csrc/core/npu/npu_log.h" #include "torch_npu/csrc/core/npu/NPUException.h" +#include "torch_npu/csrc/core/npu/interface/AclInterface.h" #include "torch_npu/csrc/profiler/npu_profiler.h" #include "torch_npu/csrc/toolkit/profiler/common/utils.h" @@ -380,6 +381,8 @@ void reportMemoryDataToNpuProfiler(const MemoryUsage& data) if (!ProfilerMgr::GetInstance()->ReportMemEnable().load()) { return; } + int32_t stream_id; + c10_npu::acl::AclrtStreamGetId(data.stream, &stream_id); ProfilerMgr::GetInstance()->UploadWithLock(std::make_unique( data.ptr, 
static_cast(Utils::GetClockTime()), @@ -387,7 +390,7 @@ void reportMemoryDataToNpuProfiler(const MemoryUsage& data) data.total_allocated, data.total_reserved, data.total_active, - data.stream_ptr, + stream_id, data.device_type, data.device_index, data.component_type, diff --git a/torch_npu/csrc/profiler/npu_profiler.h b/torch_npu/csrc/profiler/npu_profiler.h index 33d0a8cf92..bde3462b75 100644 --- a/torch_npu/csrc/profiler/npu_profiler.h +++ b/torch_npu/csrc/profiler/npu_profiler.h @@ -7,6 +7,9 @@ #include +#include "third_party/acl/inc/acl/acl_base.h" +#include "third_party/acl/inc/acl/acl_rt.h" + #include "torch_npu/csrc/toolkit/profiler/inc/data_reporter.h" #include "torch_npu/csrc/profiler/profiler_mgr.h" #include "torch_npu/csrc/profiler/mstx_mgr.h" @@ -55,7 +58,17 @@ struct MemoryUsage { int64_t total_allocated{ 0 }; int64_t total_reserved{ 0 }; int64_t total_active{ 0 }; - int64_t stream_ptr{ 0 }; + aclrtStream stream{nullptr}; + int8_t device_type{0}; + int8_t device_index{0}; + uint8_t component_type{static_cast(MemoryComponentType::CACHING_ALLOCATOR)}; + uint8_t data_type{static_cast(MemoryDataType::MEMORY_INVALID)}; + uint8_t allocator_type{static_cast(MemoryAllocatorType::ALLOCATOR_INVALID)}; + int64_t ptr{0}; + int64_t alloc_size{0}; + int64_t total_allocated{0}; + int64_t total_reserved{0}; + int64_t total_active{0}; }; struct ExperimentalConfig { diff --git a/torch_npu/npu/_format.py b/torch_npu/npu/_format.py new file mode 100644 index 0000000000..beb65e076f --- /dev/null +++ b/torch_npu/npu/_format.py @@ -0,0 +1,38 @@ +from enum import IntEnum + +import torch +import torch_npu + + +class Format(IntEnum): + """NPU storage format enumeration class""" + UNDEFINED = -1 + NCHW = 0 + NHWC = 1 + ND = 2 + NC1HWC0 = 3 + FRACTAL_Z = 4 + NC1HWC0_C04 = 12 + HWCN = 16 + NDHWC = 27 + FRACTAL_NZ = 29 + NCDHW = 30 + NDC1HWC0 = 32 + FRACTAL_Z_3D = 33 + NC = 35 + NCL = 47 + + def __str__(self): + return self.name + + +def _apply_npu_format_patch(): + orig_get_format = torch_npu.get_npu_format + + def patched_get_format(tensor): + """get the Format type of tensor""" + format_int = orig_get_format(tensor) + return Format(format_int) + + torch_npu.get_npu_format = patched_get_format + torch_npu.Format = Format -- Gitee
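
Usage sketch for the new torch_npu.Format enum, mirroring the added test/npu/test_npu_format.py (assumes an NPU device is available; the patched torch_npu.get_npu_format returns a Format member, which as an IntEnum still compares equal to the raw ACL format id):

    import torch
    import torch_npu

    x = torch.ones(2, 2).npu()
    # npu_format_cast accepts either the raw ACL format id or the new enum
    y = torch_npu.npu_format_cast(x, torch_npu.Format.NHWC)
    assert torch_npu.get_npu_format(y) == torch_npu.Format.NHWC
    assert torch_npu.get_npu_format(y) == 1  # IntEnum keeps integer comparison

For reference, the NSLB-DP job id written into HcclCommConfig packs the TCP master endpoint as (port << 32) + IPv4 address in host byte order; a rough Python equivalent of the arithmetic in setNSLBCommConfig (the address and port below are placeholders, not values from this patch):

    import socket
    import struct

    NSLB_JOBID_OFFSET = 32
    master_ip = struct.unpack("!I", socket.inet_aton("192.0.2.1"))[0]  # ntohl of the packed address
    job_id = (29500 << NSLB_JOBID_OFFSET) + master_ip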