diff --git a/CMakeLists.txt b/CMakeLists.txt index b88ceef030e47eee2a3a864443443b9917cb8c02..9094b223c3fe9510be6b4015a7214b2eaaaa2b6b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -235,6 +235,7 @@ if (NOT DEFINED BUILD_LIBTORCH) set(FLOP_SRCS) set(NPU_SRCS) set(PROF_SRCS) + set(IPC_SRCS) set(UTILS_SRCS) set(SAN_SRCS) endif() @@ -254,6 +255,7 @@ if (NOT DEFINED BUILD_LIBTORCH) add_subdirectory(${TORCHNPU_ROOT}/distributed) add_subdirectory(${TORCHNPU_ROOT}/npu) add_subdirectory(${TORCHNPU_ROOT}/profiler) + add_subdirectory(${TORCHNPU_ROOT}/ipc) add_subdirectory(${TORCHNPU_ROOT}/utils) add_subdirectory(${TORCHNPU_ROOT}/sanitizer) endif() @@ -285,7 +287,7 @@ if (DEFINED BUILD_LIBTORCH) set(CPP_SRCS ${ATEN_SRCS} ${CORE_SRCS} ${OPS_PLUGIN_SRCS} ${FLOP_SRCS} ${CUS_DTYPE_SRCS} ${FRAMEWORK_SRCS} ${LOGGING_SRCS} ${NPU_CPP_LIBS_SRCS}) else() # Compile code with pybind11 - set(CPP_SRCS ${ATEN_SRCS} ${CORE_SRCS} ${OPS_PLUGIN_SRCS} ${DIST_SRCS} ${FLOP_SRCS} ${CUS_DTYPE_SRCS} ${LOGGING_SRCS} ${FRAMEWORK_SRCS} ${NPU_SRCS} ${PROF_SRCS} ${UTILS_SRCS} ${SAN_SRCS}) + set(CPP_SRCS ${ATEN_SRCS} ${CORE_SRCS} ${OPS_PLUGIN_SRCS} ${DIST_SRCS} ${FLOP_SRCS} ${CUS_DTYPE_SRCS} ${LOGGING_SRCS} ${FRAMEWORK_SRCS} ${NPU_SRCS} ${PROF_SRCS} ${IPC_SRCS} ${UTILS_SRCS} ${SAN_SRCS}) endif() add_library(${PLUGIN_NAME} SHARED ${CPP_SRCS}) diff --git a/env.sh b/env.sh index ff54b797d211caad86b37132a8fdc101157c1388..96fa71d80f4f94d140314654b82bfe8fa0f469c2 100644 --- a/env.sh +++ b/env.sh @@ -1,3 +1,4 @@ +#!/bin/bash # 配置CANN相关环境变量 CANN_INSTALL_PATH_CONF='/etc/Ascend/ascend_cann_install.info' diff --git a/test/allocator/test_pluggable_allocator_extensions.py b/test/allocator/test_pluggable_allocator_extensions.py index 99cc499a93c457b0c6732dd3de015c76a280c695..54e270513d3031419695c0488a75916309e2ae30 100644 --- a/test/allocator/test_pluggable_allocator_extensions.py +++ b/test/allocator/test_pluggable_allocator_extensions.py @@ -2,6 +2,7 @@ import os import sys import shutil import subprocess +import ctypes import torch import torch.utils.cpp_extension @@ -27,6 +28,7 @@ def build_stub(base_dir): class TestPluggableAllocator(TestCase): module = None + new_alloc = None build_directory = "allocator/build" @classmethod @@ -59,9 +61,9 @@ class TestPluggableAllocator(TestCase): def test_pluggable_allocator(self): os_path = os.path.join(TestPluggableAllocator.build_directory, 'pluggable_allocator_extensions.so') # Load the allocator - new_alloc = torch_npu.npu.memory.NPUPluggableAllocator(os_path, 'my_malloc', 'my_free') + TestPluggableAllocator.new_alloc = torch_npu.npu.memory.NPUPluggableAllocator(os_path, 'my_malloc', 'my_free') # Swap the current allocator - torch_npu.npu.memory.change_current_allocator(new_alloc) + torch_npu.npu.memory.change_current_allocator(TestPluggableAllocator.new_alloc) # This will allocate memory in the device using the new allocator self.assertFalse(self.module.check_custom_allocator_used()) npu_tensor = torch.zeros(10, device='npu') @@ -69,6 +71,31 @@ class TestPluggableAllocator(TestCase): self.assertRtolEqual(npu_tensor.cpu().numpy(), cpu_tensor.numpy()) self.assertTrue(self.module.check_custom_allocator_used()) + def test_set_get_device_stats_fn(self): + os_path = os.path.join(TestPluggableAllocator.build_directory, 'pluggable_allocator_extensions.so') + myallocator = ctypes.CDLL(os_path) + get_device_stats_fn = ctypes.cast(getattr(myallocator, "my_get_device_stats"), ctypes.c_void_p).value + + msg = "get_device_stats_fn_ is not define, please set by set_get_device_stats_fn" + with 
self.assertRaisesRegex(RuntimeError, msg): + torch.npu.memory_stats_as_nested_dict() + + TestPluggableAllocator.new_alloc.allocator().set_get_device_stats_fn(get_device_stats_fn) + self.assertEqual(torch.npu.memory_stats_as_nested_dict()["num_alloc_retries"], 0) + + def test_set_reset_peak_status_fn(self): + os_path = os.path.join(TestPluggableAllocator.build_directory, 'pluggable_allocator_extensions.so') + myallocator = ctypes.CDLL(os_path) + reset_peak_status_fn = ctypes.cast(getattr(myallocator, "my_reset_peak_status"), ctypes.c_void_p).value + + msg = "reset_peak_status_fn_ is not define, please set by set_reset_peak_status_fn" + with self.assertRaisesRegex(RuntimeError, msg): + torch.npu.reset_peak_memory_stats() + + TestPluggableAllocator.new_alloc.allocator().set_reset_peak_status_fn(reset_peak_status_fn) + torch.npu.reset_peak_memory_stats() + self.assertEqual(torch.npu.max_memory_allocated(), 0) + def test_pluggable_allocator_after_init(self): os_path = os.path.join(TestPluggableAllocator.build_directory, 'pluggable_allocator_extensions.so') # Do an initial memory allocator diff --git a/test/allowlist_for_publicAPI.json b/test/allowlist_for_publicAPI.json index b91d8fb67e26a915eb9b33fedca2fe328dfc4feb..adeaa84860d04f788637b4f13b12ebd6c5a5f0d9 100644 --- a/test/allowlist_for_publicAPI.json +++ b/test/allowlist_for_publicAPI.json @@ -584,7 +584,9 @@ "ForkingPickler", "Union", "check_serializing_named_tensor", - "register_after_fork" + "register_after_fork", + "reduce_tensor", + "reduce_storage" ], "torch.multiprocessing.spawn": [ "Optional" diff --git a/test/cpp_extensions/pluggable_allocator_extensions.cpp b/test/cpp_extensions/pluggable_allocator_extensions.cpp index 3ed2606b021ba7796ed6e94ad11f41625a88d169..6bb80e59dd5c4911d79fcb50cadc69b6f6babdbb 100644 --- a/test/cpp_extensions/pluggable_allocator_extensions.cpp +++ b/test/cpp_extensions/pluggable_allocator_extensions.cpp @@ -4,8 +4,10 @@ #include "third_party/acl/inc/acl/acl_base.h" #include "third_party/acl/inc/acl/acl_rt.h" +#include "torch_npu/csrc/core/npu/NPUCachingAllocator.h" extern "C" { +using c10_npu::NPUCachingAllocator::DeviceStats; static bool useflag = false; void* my_malloc(ssize_t size, int device, aclrtStream stream) @@ -27,6 +29,17 @@ bool check_custom_allocator_used() { return useflag; } + +DeviceStats my_get_device_stats(int device) +{ + DeviceStats stats; + return stats; +} + +void my_reset_peak_status(int device) +{ + std::cout<<"resetPeakStatus success!"<(StatType::NUM_TYPES)>; void update_stat(Stat &stat, int64_t amount) @@ -355,7 +362,10 @@ bevhavior for allocator tensors that need to be used cross-process. 
*/ struct ExpandableSegment { - ExpandableSegment(int device, aclrtStream stream, size_t size) + ExpandableSegment( + int device, + std::optional stream, + size_t size) : device_(device), stream_(stream), max_handles_(0), @@ -376,7 +386,7 @@ struct ExpandableSegment { auto default_stream = c10_npu::getDefaultNPUStream().stream(false); if (kSmallBuffer == segment_size_) { max_handles_ = numSegments(kSmallPoolVirAddrSize); - } else if (default_stream != stream) { + } else if (default_stream != *stream) { max_handles_ = numSegments(kLargePoolVirAddrSize); } } @@ -416,17 +426,17 @@ struct ExpandableSegment { for (auto j : c10::irange(begin, i)) { auto h = handles_.at(j).value(); handles_.at(j) = c10::nullopt; - NPU_CHECK_ERROR(c10_npu::acl::AclrtFreePhysical(h)); + NPU_CHECK_ERROR(c10_npu::acl::AclrtFreePhysical(h.handle)); } trimHandles(); return rangeFromHandles(begin, begin); } NPU_CHECK_ERROR(status, "aclrtMallocPhysical"); - handles_.at(i) = handle; + handles_.at(i) = Handle{handle, std::nullopt}; } for (auto i : c10::irange(begin, end)) { NPU_CHECK_ERROR(c10_npu::acl::AclrtMapMem((char *)ptr_ + i * segment_size_, segment_size_, 0, - handles_.at(i).value(), 0, getHcclComm())); + handles_.at(i).value().handle, 0, getHcclComm())); } ASCEND_LOGD("NPUCachingAllocator map: segment_size=%zu", segment_size_); return rangeFromHandles(begin, end); @@ -446,6 +456,59 @@ struct ExpandableSegment { return rangeFromHandles(begin, end); } + // Setup IPC sharing for range. + // Returns the (larger) range that was actually shared. + // Serializes data to std::ostream that can be passed to the + // other process, and then restored as an exapandable segment + // via ExpandableSegment::fromShared(istream); + SegmentRange share(SegmentRange range, std::ostream& buf) + { + auto begin = segmentLeft(range.ptr); + auto end = segmentRight(range.ptr + range.size); + ShareHeader header{segment_size_, end - begin}; + buf.write((const char*)&header, sizeof(ShareHeader)); + for (auto i : c10::irange(begin, end)) { + // NOLINTNEXTLINE(bugprone-unchecked-optional-access) + auto& handle = handles_.at(i).value(); + if (!handle.shareableHandle) { + uint64_t shareableHandle = 0; + NPU_CHECK_ERROR(c10_npu::acl::AclrtMemExportToShareableHandle( + handle.handle, ACL_MEM_HANDLE_TYPE_NONE, 0, &shareableHandle)); + int32_t* pids = nullptr; + int pid_num = torch_npu::ipc::getPids(&pids); + NPU_CHECK_ERROR(c10_npu::acl::AclrtMemSetPidToShareableHandle(shareableHandle, pids, pid_num)); + handle.shareableHandle = shareableHandle; + } + uint64_t shandle = *handle.shareableHandle; + buf.write((const char*)&shandle, sizeof(uint64_t)); + } + return rangeFromHandles(begin, end); + } + + static std::unique_ptr fromShared( + c10::DeviceIndex device, + std::istream& buf) + { + ShareHeader header{}; + buf.read((char*)&header, sizeof(ShareHeader)); + auto segment = std::make_unique( + device, + std::nullopt, + header.segment_size); + for (auto i : c10::irange(header.num_handles)) { + (void)i; + uint64_t shareableHandle = 0; + buf.read((char*)&shareableHandle, sizeof(uint64_t)); + int32_t deviceId = static_cast(device); + aclrtDrvMemHandle handle; + NPU_CHECK_ERROR(c10_npu::acl::AclrtMemImportFromShareableHandle( + shareableHandle, deviceId, &handle)); + segment->handles_.emplace_back(Handle{handle, std::nullopt}); + } + segment->mapAndSetAccess(0, header.num_handles); + return segment; + } + char *ptr() const { return (char *)ptr_; @@ -464,7 +527,7 @@ struct ExpandableSegment { segment_size_ * max_handles_, 0, 1)); for (auto i : 
c10::irange(handles_.size())) { HCCL_CHECK_ERROR(at_npu::hccl::HcclCommActivateCommMemoryFace(hcclComm_->getHcclComm(), - (char *)ptr_ + i * segment_size_, segment_size_, 0, handles_.at(i).value(), 0)); + (char *)ptr_ + i * segment_size_, segment_size_, 0, handles_.at(i).value().handle, 0)); } } @@ -476,6 +539,15 @@ struct ExpandableSegment { } private: + void mapAndSetAccess(size_t begin, size_t end) + { + for (auto i : c10::irange(begin, end)) { + NPU_CHECK_ERROR(c10_npu::acl::AclrtMapMem((char *)ptr_ + i * segment_size_, segment_size_, 0, + handles_.at(i).value().handle, 0, getHcclComm())); + } + ASCEND_LOGD("NPUCachingAllocator mapAndSetAccess: segment_size=%zu", segment_size_); + } + void unmapHandles(size_t begin, size_t end) { // note: unlike aclrtFree, MemUnmap and MemRelease do @@ -485,18 +557,23 @@ private: // cannot call c10::npu::stream_synchronize because // it might grab the GIL which can lead to a deadlock // Locking order must be GIL -> Allocator Lock - NPU_CHECK_ERROR(aclrtSynchronizeStream(stream_)); + if (stream_) { + NPU_CHECK_ERROR(aclrtSynchronizeStream(*stream_)); + } else { + c10_npu::NPUGuard device_guard(device_); + c10_npu::npuSynchronizeDevice(true); + } #ifndef BUILD_LIBTORCH const c10_npu::impl::PyCallbackTrigger *trigger = c10_npu::impl::NPUTrace::getTrace(); if (C10_UNLIKELY(trigger)) { - trigger->traceNpuStreamSynchronization(reinterpret_cast(stream_)); + trigger->traceNpuStreamSynchronization(reinterpret_cast(*stream_)); } #endif for (auto i : c10::irange(begin, end)) { - aclrtDrvMemHandle h = handles_.at(i).value(); + Handle h = handles_.at(i).value(); handles_.at(i) = c10::nullopt; NPU_CHECK_ERROR(c10_npu::acl::AclrtUnmapMem((char *)ptr_ + segment_size_ * i, getHcclComm())); - NPU_CHECK_ERROR(c10_npu::acl::AclrtFreePhysical(h)); + NPU_CHECK_ERROR(c10_npu::acl::AclrtFreePhysical(h.handle)); } ASCEND_LOGD("NPUCachingAllocator unmap: segment_size=%zu", segment_size_); trimHandles(); @@ -553,11 +630,19 @@ private: } int device_; - aclrtStream stream_; + std::optional stream_; void *ptr_{}; size_t max_handles_; size_t segment_size_; - std::vector> handles_; + struct Handle { + aclrtDrvMemHandle handle; + std::optional shareableHandle; + }; + struct ShareHeader { + size_t segment_size; + size_t num_handles; + }; + std::vector> handles_; std::shared_ptr hcclComm_; }; @@ -1014,6 +1099,13 @@ private: std::unique_lock& lock_; }; +struct handle_str { + char data[ACL_IPC_HANDLE_SIZE]; +}; + +// handle for ptr +ska::flat_hash_map ipc_handle_map; + class DeviceCachingAllocator { private: // lock around all operations @@ -1549,6 +1641,40 @@ public: return basePtr; } + ShareableHandle shareIpcHandle(Block* block) + { + std::lock_guard lock(mutex); + std::ostringstream ss; + ss.put(SHAREABLE_HANDLE_VERSION); + ptrdiff_t offset = 0; + if (!block->expandable_segment_) { + ss.put(SHAREABLE_NPU_MALLOC); + size_t base_size; + void* base_ptr = getBaseAllocation(block, &base_size); + offset = (char*)block->ptr - (char*)base_ptr; + + handle_str handle; + auto it = ipc_handle_map.find(base_ptr); + if (it == ipc_handle_map.end()) { + NPU_CHECK_ERROR(c10_npu::acl::AclrtIpcMemGetExportKey( + base_ptr, base_size, handle.data, ACL_IPC_HANDLE_SIZE)); + int32_t* pids = nullptr; + int pid_num = torch_npu::ipc::getPids(&pids); + NPU_CHECK_ERROR(c10_npu::acl::AclrtIpcMemSetImportPid(handle.data, pids, pid_num)); + ipc_handle_map[base_ptr] = handle; + } else { + handle = it->second; + } + ss.write((char*)&handle, ACL_IPC_HANDLE_SIZE); + } else { + ss.put(SHAREABLE_NPU_EXPANDABLE_SEGMENT); + 
auto full_range = block->expandable_segment_->share( + SegmentRange(block->ptr, block->size), ss); + offset = (char*)block->ptr - (char*)full_range.ptr; + } + return ShareableHandle{offset, ss.str()}; + } + void recordStream(Block *block, c10_npu::NPUStream stream) { std::lock_guard lock(mutex); @@ -2703,6 +2829,12 @@ private: record_trace(TraceEntry::SEGMENT_FREE, int64_t(block->ptr), block->size, block->stream, block->device, context ? context : block->context_when_segment_allocated); + auto it = ipc_handle_map.find(block->ptr); + if (it != ipc_handle_map.end()) { + NPU_CHECK_ERROR(c10_npu::acl::AclrtIpcMemClose(it->second.data)); + ipc_handle_map.erase(it); + } + aclrtFree((void *)block->ptr); total_allocated_memory -= block->size; @@ -3178,6 +3310,15 @@ public: return device_allocator[block->device]->getBaseAllocation(block, outSize); } + ShareableHandle shareIpcHandle(void* ptr) override + { + Block* block = get_allocated_block(ptr); + if (!block) { + AT_ERROR("invalid device pointer: ", ptr); + } + return device_allocator[block->device]->shareIpcHandle(block); + } + void recordStream(const c10::DataPtr &ptr, c10_npu::NPUStream stream) override { // Empty tensor's storage().data() might be a null ptr. As there is no @@ -3435,6 +3576,109 @@ public: this->free(ptr); } + std::mutex IpcMutex; + struct MemHandleCacheEntry { + MemHandleCacheEntry( + c10::DeviceIndex device, + std::string& handle, + const DeviceCachingAllocator& allocator) + : device_(device) + { + int type = SHAREABLE_NPU_MALLOC; + std::istringstream ss(handle); + if (handle.size() != ACL_IPC_HANDLE_SIZE) { + auto version = ss.get(); + TORCH_CHECK( + version <= SHAREABLE_HANDLE_VERSION, + "received sharable handle from a future version of torch that this version does not know how to handle", + PTA_ERROR(ErrCode::NOT_SUPPORT)); + type = ss.get(); + } + // otherwise this is coming from an old pytorch where it has to be a raw + // SHAREABLE_NPU_MALLOC + if (type == SHAREABLE_NPU_MALLOC) { + handle_str handle_r; + ss.read(handle_r.data, ACL_IPC_HANDLE_SIZE); + NPU_CHECK_ERROR(c10_npu::acl::AclrtIpcMemImportByKey(&npu_ipc_ptr_, handle_r.data)); + handle_s.assign(handle_r.data, ACL_IPC_HANDLE_SIZE); + } else if (type == SHAREABLE_NPU_EXPANDABLE_SEGMENT) { + expandable_segment_ = + ExpandableSegment::fromShared(device, ss) + .release(); + } else { + TORCH_INTERNAL_ASSERT( + false, "Unexpected or illformed shareable handle type"); + } + } + // this struct expects that clear is explicitly called to + // free resources, because we only want this code running when + // the shared pointer to this entry is destructed, not during + // deinitialization when npu may already have been shutdown. + // This replicates the previous behavior of this map when it + // stored raw npu_ipc_ptr_ handles. 
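For orientation, the handle string that shareIpcHandle produces and MemHandleCacheEntry parses is a small binary blob: one version byte, one type byte, then either a fixed-size ACL IPC key (SHAREABLE_NPU_MALLOC) or, for expandable segments, a ShareHeader{segment_size, num_handles} followed by one 64-bit shareable handle per physical handle. The Python sketch below only illustrates that layout; the tag values and ACL_IPC_HANDLE_SIZE are placeholder assumptions, not the values defined in the C++ sources.

    import struct

    # Placeholder constants -- the real values live in the C++ sources and are
    # NOT guaranteed to match these assumptions.
    SHAREABLE_HANDLE_VERSION = 1
    SHAREABLE_NPU_MALLOC = 0               # tag for aclrtMalloc-backed blocks
    SHAREABLE_NPU_EXPANDABLE_SEGMENT = 1   # tag for expandable segments
    ACL_IPC_HANDLE_SIZE = 65               # assumed length of an ACL IPC key

    def pack_expandable(segment_size, shareable_handles):
        # Mirrors ExpandableSegment::share: ShareHeader{segment_size, num_handles}
        # followed by one uint64 shareable handle per physical handle.
        blob = bytes([SHAREABLE_HANDLE_VERSION, SHAREABLE_NPU_EXPANDABLE_SEGMENT])
        blob += struct.pack("=QQ", segment_size, len(shareable_handles))
        for h in shareable_handles:
            blob += struct.pack("=Q", h)
        return blob

    def unpack(blob):
        # Mirrors MemHandleCacheEntry: a blob of exactly ACL_IPC_HANDLE_SIZE bytes
        # is treated as a bare legacy IPC key with no version/type prefix.
        if len(blob) == ACL_IPC_HANDLE_SIZE:
            return ("ipc_key", blob)
        version, tag = blob[0], blob[1]
        if tag == SHAREABLE_NPU_MALLOC:
            return ("ipc_key", blob[2:2 + ACL_IPC_HANDLE_SIZE])
        segment_size, num_handles = struct.unpack_from("=QQ", blob, 2)
        handles = struct.unpack_from(f"={num_handles}Q", blob, 18)
        return ("expandable", segment_size, list(handles))

On the consumer side, the expandable-segment payload is handed to ExpandableSegment::fromShared, which imports each shareable handle with AclrtMemImportFromShareableHandle and remaps the whole range via mapAndSetAccess.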
+ void clear() + { + if (npu_ipc_ptr_) { + c10_npu::NPUGuard device_guard(device_); + NPU_CHECK_ERROR(c10_npu::acl::AclrtIpcMemClose(handle_s.c_str())); + npu_ipc_ptr_ = nullptr; + } + if (expandable_segment_) { + delete expandable_segment_; + expandable_segment_ = nullptr; + } + } + void* ptr() + { + if (npu_ipc_ptr_) { + return npu_ipc_ptr_; + } else { + return expandable_segment_->ptr(); + } + } + c10::DeviceIndex device_; + ExpandableSegment* expandable_segment_{nullptr}; + void* npu_ipc_ptr_{nullptr}; // nullptr if expandable_segment_ is not null + std::weak_ptr wp_; + std::string handle_s; + }; + ska::flat_hash_map ipcMemHandle_to_devptr; + + std::shared_ptr getIpcDevPtr(std::string handle) override + { + std::lock_guard lock(IpcMutex); + + auto iter = ipcMemHandle_to_devptr.find(handle); + if (iter != ipcMemHandle_to_devptr.end()) { + auto devptr = iter->second.wp_.lock(); + TORCH_INTERNAL_ASSERT(devptr, "entry in cache has missing shared_ptr"); + return devptr; + } + int curr_device = 0; + NPU_CHECK_ERROR(c10_npu::GetDevice(&curr_device)); + auto inserted = ipcMemHandle_to_devptr.insert( + iter, + {handle, + MemHandleCacheEntry( + static_cast(curr_device), handle, *device_allocator[curr_device])}); + auto sp = std::shared_ptr( + inserted->second.ptr(), [handle, this](void* ptr) { + std::unique_lock deleter_lock(IpcMutex); + + auto it = ipcMemHandle_to_devptr.find(handle); + TORCH_INTERNAL_ASSERT(it != ipcMemHandle_to_devptr.end()); + auto entry = std::move(it->second); + ipcMemHandle_to_devptr.erase(it); + + // ExpandableSegment synchronizes on destruction in unmapHandles, so + // we need to release the lock first to minimize the performance hit. + deleter_lock.unlock(); + entry.clear(); + }); + inserted->second.wp_ = sp; + return sp; + } + void FreeDeviceCachedMemory(int device) override { device_allocator[device]->emptyCache(device, true); diff --git a/torch_npu/csrc/core/npu/NPUCachingAllocator.h b/torch_npu/csrc/core/npu/NPUCachingAllocator.h index a4e14d2232ab30f7a3cd4e991c904f404b18f6a5..c7082c89044158360f39373593e2deabb658b776 100644 --- a/torch_npu/csrc/core/npu/NPUCachingAllocator.h +++ b/torch_npu/csrc/core/npu/NPUCachingAllocator.h @@ -188,6 +188,11 @@ using OutOfMemoryObserver = std::function; +struct ShareableHandle { + ptrdiff_t offset; + std::string handle; +}; + class NPUAllocator : public c10::Allocator { public: virtual c10::DataPtr allocate_with_aligned(size_t size, size_t aligned) const = 0; @@ -227,6 +232,8 @@ public: " does not yet support checkPoolLiveAllocations. 
" "If you need it, please file an issue describing your use case.", PTA_ERROR(ErrCode::NOT_SUPPORT)); } + virtual ShareableHandle shareIpcHandle(void* ptr) = 0; + virtual std::shared_ptr getIpcDevPtr(std::string handle) = 0; virtual bool isHistoryEnabled() { TORCH_CHECK( @@ -376,6 +383,16 @@ inline void releasePool(c10::DeviceIndex device, MempoolId_t mempool_id) return get()->releasePool(device, mempool_id); } +inline std::shared_ptr getIpcDevPtr(std::string handle) +{ + return get()->getIpcDevPtr(handle); +} + +inline ShareableHandle shareIpcHandle(void* ptr) +{ + return get()->shareIpcHandle(ptr); +} + inline void FreeDeviceCachedMemory(int device) { return get()->FreeDeviceCachedMemory(device); diff --git a/torch_npu/csrc/core/npu/NPUIPCPidManager.cpp b/torch_npu/csrc/core/npu/NPUIPCPidManager.cpp new file mode 100644 index 0000000000000000000000000000000000000000..94bbd2739abc9464bf3303a92ee75a9a1750378c --- /dev/null +++ b/torch_npu/csrc/core/npu/NPUIPCPidManager.cpp @@ -0,0 +1,36 @@ +#include "torch_npu/csrc/core/npu/NPUIPCPidManager.h" +namespace torch_npu { +namespace ipc { + +int32_t* pids = nullptr; +int pid_num = 0; +int capacity = 0; + +void addPid(int pid) +{ + const int requiredCapacity = pid_num + 1; + + if (requiredCapacity > capacity) { + int newCapacity = capacity + 10; + + int32_t* newArray = new int32_t[newCapacity]; + for (int i = 0; i < pid_num; ++i) { + newArray[i] = pids[i]; + } + + delete[] pids; + pids = newArray; + capacity = newCapacity; + } + + pids[pid_num++] = static_cast(pid); +} + +int getPids(int32_t** ret_pids) +{ + *ret_pids = pids; + return pid_num; +} + +} // namespace ipc +} // namespace torch_npu \ No newline at end of file diff --git a/torch_npu/csrc/core/npu/NPUIPCPidManager.h b/torch_npu/csrc/core/npu/NPUIPCPidManager.h new file mode 100644 index 0000000000000000000000000000000000000000..bc5a72cd891c347ddea3e42000a9e5f94e19d735 --- /dev/null +++ b/torch_npu/csrc/core/npu/NPUIPCPidManager.h @@ -0,0 +1,11 @@ +#pragma once +#include + +namespace torch_npu { +namespace ipc { + +void addPid(int pid); +int getPids(int32_t** pids); + +} // namespace ipc +} // namespace torch_npu \ No newline at end of file diff --git a/torch_npu/csrc/core/npu/NPUQueue.cpp b/torch_npu/csrc/core/npu/NPUQueue.cpp index bd29315e057b8e14ee9189bde7c802f3e73558b9..2fa4c4766a940e92e3c6933135bec478cbaa50fe 100644 --- a/torch_npu/csrc/core/npu/NPUQueue.cpp +++ b/torch_npu/csrc/core/npu/NPUQueue.cpp @@ -314,7 +314,7 @@ NPUStatus Repository::MakeSureQueueEmpty(bool check_error) repo_error + ".\n" + "Since the operator is called asynchronously, the stacktrace may be inaccurate. " "If you want to get the accurate stacktrace, " - "pleace set the environment variable ASCEND_LAUNCH_BLOCKING=1.\n" + + "please set the environment variable ASCEND_LAUNCH_BLOCKING=1.\n" + "Note: ASCEND_LAUNCH_BLOCKING=1 will force ops to run in synchronous mode, " "resulting in performance degradation. " "Please unset ASCEND_LAUNCH_BLOCKING in time after debugging." + @@ -490,7 +490,7 @@ void Repository::Enqueue(void *cur_paras) repo_error + ".\n" + "Since the operator is called asynchronously, the stacktrace may be inaccurate. " "If you want to get the accurate stacktrace, " - "pleace set the environment variable ASCEND_LAUNCH_BLOCKING=1.\n" + + "please set the environment variable ASCEND_LAUNCH_BLOCKING=1.\n" + "Note: ASCEND_LAUNCH_BLOCKING=1 will force ops to run in synchronous mode, " "resulting in performance degradation. " "Please unset ASCEND_LAUNCH_BLOCKING in time after debugging." 
+ diff --git a/torch_npu/csrc/core/npu/interface/AclInterface.cpp b/torch_npu/csrc/core/npu/interface/AclInterface.cpp index 39c4b534435ada7975086b92f54928c5312e35b5..c46740b72db4f4e369d8935a5ca45c0d60c588f2 100644 --- a/torch_npu/csrc/core/npu/interface/AclInterface.cpp +++ b/torch_npu/csrc/core/npu/interface/AclInterface.cpp @@ -82,6 +82,13 @@ LOAD_FUNCTION(aclmdlRICaptureTaskUpdateBegin) LOAD_FUNCTION(aclmdlRICaptureTaskUpdateEnd) LOAD_FUNCTION(aclrtHostRegister) LOAD_FUNCTION(aclrtHostUnregister) +LOAD_FUNCTION(aclrtIpcMemGetExportKey) +LOAD_FUNCTION(aclrtIpcMemSetImportPid) +LOAD_FUNCTION(aclrtIpcMemImportByKey) +LOAD_FUNCTION(aclrtIpcMemClose) +LOAD_FUNCTION(aclrtMemExportToShareableHandle) +LOAD_FUNCTION(aclrtMemSetPidToShareableHandle) +LOAD_FUNCTION(aclrtMemImportFromShareableHandle) aclprofStepInfoPtr init_stepinfo() { typedef aclprofStepInfoPtr(*npdInitFunc)(); @@ -929,5 +936,90 @@ aclError AclrtHostUnregister(void *ptr) return func(ptr); } +aclError AclrtIpcMemGetExportKey(void *devPtr, size_t size, char *name, size_t len) +{ + typedef aclError (*AclrtIpcMemGetExportKey)(void *, size_t, char *, size_t); + static AclrtIpcMemGetExportKey func = nullptr; + if (func == nullptr) { + func = (AclrtIpcMemGetExportKey) GET_FUNC(aclrtIpcMemGetExportKey); + } + + TORCH_CHECK(func, "Failed to find function aclrtIpcMemGetExportKey", PTA_ERROR(ErrCode::NOT_FOUND)); + return func(devPtr, size, name, len); +} + +aclError AclrtIpcMemSetImportPid(const char *name, int32_t pid[], int num) +{ + typedef aclError (*AclrtIpcMemSetImportPid)(const char *, int32_t[], int); + static AclrtIpcMemSetImportPid func = nullptr; + if (func == nullptr) { + func = (AclrtIpcMemSetImportPid) GET_FUNC(aclrtIpcMemSetImportPid); + } + + TORCH_CHECK(func, "Failed to find function aclrtIpcMemSetImportPid", PTA_ERROR(ErrCode::NOT_FOUND)); + return func(name, pid, num); +} + +aclError AclrtIpcMemImportByKey(void **devPtr, const char *name) +{ + typedef aclError (*AclrtIpcMemImportByKey)(void **, const char *); + static AclrtIpcMemImportByKey func = nullptr; + if (func == nullptr) { + func = (AclrtIpcMemImportByKey) GET_FUNC(aclrtIpcMemImportByKey); + } + + TORCH_CHECK(func, "Failed to find function aclrtIpcMemImportByKey", PTA_ERROR(ErrCode::NOT_FOUND)); + return func(devPtr, name); +} + +aclError AclrtIpcMemClose(const char *name) +{ + typedef aclError (*AclrtIpcMemClose)(const char *); + static AclrtIpcMemClose func = nullptr; + if (func == nullptr) { + func = (AclrtIpcMemClose) GET_FUNC(aclrtIpcMemClose); + } + + TORCH_CHECK(func, "Failed to find function aclrtIpcMemClose", PTA_ERROR(ErrCode::NOT_FOUND)); + return func(name); +} + +aclError AclrtMemExportToShareableHandle(aclrtDrvMemHandle handle, aclrtMemHandleType handleType, + uint64_t flags, uint64_t *shareableHandle) +{ + typedef aclError (*AclrtMemExportToShareableHandle)(aclrtDrvMemHandle, aclrtMemHandleType, uint64_t, uint64_t *); + static AclrtMemExportToShareableHandle func = nullptr; + if (func == nullptr) { + func = (AclrtMemExportToShareableHandle) GET_FUNC(aclrtMemExportToShareableHandle); + } + + TORCH_CHECK(func, "Failed to find function aclrtMemExportToShareableHandle", PTA_ERROR(ErrCode::NOT_FOUND)); + return func(handle, handleType, flags, shareableHandle); +} + +aclError AclrtMemSetPidToShareableHandle(uint64_t shareableHandle, int32_t *pid, size_t pidNum) +{ + typedef aclError (*AclrtMemSetPidToShareableHandle)(uint64_t, int32_t *, size_t); + static AclrtMemSetPidToShareableHandle func = nullptr; + if (func == nullptr) { + func = 
(AclrtMemSetPidToShareableHandle) GET_FUNC(aclrtMemSetPidToShareableHandle); + } + + TORCH_CHECK(func, "Failed to find function aclrtMemSetPidToShareableHandle", PTA_ERROR(ErrCode::NOT_FOUND)); + return func(shareableHandle, pid, pidNum); +} + +aclError AclrtMemImportFromShareableHandle(uint64_t shareableHandle, int32_t deviceId, aclrtDrvMemHandle *handle) +{ + typedef aclError (*AclrtMemImportFromShareableHandle)(uint64_t, int32_t, aclrtDrvMemHandle *); + static AclrtMemImportFromShareableHandle func = nullptr; + if (func == nullptr) { + func = (AclrtMemImportFromShareableHandle) GET_FUNC(aclrtMemImportFromShareableHandle); + } + + TORCH_CHECK(func, "Failed to find function aclrtMemImportFromShareableHandle", PTA_ERROR(ErrCode::NOT_FOUND)); + return func(shareableHandle, deviceId, handle); +} + } // namespace acl } // namespace c10 diff --git a/torch_npu/csrc/core/npu/interface/AclInterface.h b/torch_npu/csrc/core/npu/interface/AclInterface.h index fe567a77aeb7aef32d49eaeae6fd4cc0959740c8..efea0017670180af487036f6c38abd261e2bc2d6 100644 --- a/torch_npu/csrc/core/npu/interface/AclInterface.h +++ b/torch_npu/csrc/core/npu/interface/AclInterface.h @@ -228,5 +228,20 @@ aclError AclrtHostRegister(void *ptr, uint64_t size, aclrtHostRegisterType type, */ aclError AclrtHostUnregister(void *ptr); +aclError AclrtIpcMemGetExportKey(void *devPtr, size_t size, char *name, size_t len); + +aclError AclrtIpcMemSetImportPid(const char *name, int32_t pid[], int num); + +aclError AclrtIpcMemImportByKey(void **devPtr, const char *name); + +aclError AclrtIpcMemClose(const char *name); + +aclError AclrtMemExportToShareableHandle(aclrtDrvMemHandle handle, aclrtMemHandleType handleType, + uint64_t flags, uint64_t *shareableHandle); + +aclError AclrtMemSetPidToShareableHandle(uint64_t shareableHandle, int32_t *pid, size_t pidNum); + +aclError AclrtMemImportFromShareableHandle(uint64_t shareableHandle, int32_t deviceId, aclrtDrvMemHandle *handle); + } // namespace acl } // namespace c10_npu diff --git a/torch_npu/csrc/core/npu/register/OptionsManager.cpp b/torch_npu/csrc/core/npu/register/OptionsManager.cpp index e15bb200f5e2ceea5c404c480aeed965f3d4f398..c41a42ff9f360cf362c248911a2e375045aed182 100644 --- a/torch_npu/csrc/core/npu/register/OptionsManager.cpp +++ b/torch_npu/csrc/core/npu/register/OptionsManager.cpp @@ -482,10 +482,11 @@ uint32_t OptionsManager::GetAclOpInitMode() const static uint32_t acl_op_init_mode = []() -> uint32_t { char* buf_val = std::getenv("ACL_OP_INIT_MODE"); // Default 0 - int64_t acl_op_init_mode = (buf_val != nullptr) ? strtol(buf_val, nullptr, 10) : 0; + int64_t acl_op_init_mode = (buf_val != nullptr) ? 
strtol(buf_val, nullptr, 10) : 1; std::unordered_map aclOpInitMode = getAclOpInitMode(); if (aclOpInitMode.find(acl_op_init_mode) == aclOpInitMode.end()) { - TORCH_NPU_WARN_ONCE("Get env ACL_OP_INIT_MODE not in [0, 1, 2], so reset it to the default value 0."); + acl_op_init_mode = 1; + TORCH_NPU_WARN_ONCE("Get env ACL_OP_INIT_MODE not in [0, 1, 2], so reset it to the default value 1."); } return static_cast(acl_op_init_mode); }(); diff --git a/torch_npu/csrc/framework/StorageDescHelper.cpp b/torch_npu/csrc/framework/StorageDescHelper.cpp index eb568a74db28c0aa8b640372b0ed7ba6c103a89b..08a2d603b6cd96acbe4cd0a05435b15fae9d275c 100644 --- a/torch_npu/csrc/framework/StorageDescHelper.cpp +++ b/torch_npu/csrc/framework/StorageDescHelper.cpp @@ -62,9 +62,13 @@ void StorageDescHelper::UpdateDesc(torch_npu::NPUStorageDesc &npuDesc, const c10 } } npuDesc.base_strides_ = new_stride; - // 更新物理内存信息 - npuDesc.storage_sizes_ = FormatHelper::GetStorageSizes(npuDesc); + int NCDHW_OR_NDHWC_DIM = 5; + if ((npuDesc.npu_format_ == ACL_FORMAT_NCDHW || npuDesc.npu_format_ == ACL_FORMAT_NDHWC) && new_size.size() < NCDHW_OR_NDHWC_DIM) { + npuDesc.storage_sizes_ = new_size; + } else { + npuDesc.storage_sizes_ = FormatHelper::GetStorageSizes(npuDesc); + } if (new_data_numel > new_shape_numel) { // Refresh format to base format only when flattening storage data npuDesc.storage_sizes_ = new_size; diff --git a/torch_npu/csrc/ipc/CMakeLists.txt b/torch_npu/csrc/ipc/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..2c70da051f6f729c639eeb418daf0d154e6dc239 --- /dev/null +++ b/torch_npu/csrc/ipc/CMakeLists.txt @@ -0,0 +1,6 @@ +FILE(GLOB _IPC_SRCS *.cpp) + +LIST(APPEND IPC_SRCS ${_IPC_SRCS}) + +# Pass to parent +set(IPC_SRCS ${IPC_SRCS} PARENT_SCOPE) \ No newline at end of file diff --git a/torch_npu/csrc/ipc/NPUIPCTypes.cpp b/torch_npu/csrc/ipc/NPUIPCTypes.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b18b6e2f2e8fb372aa2b91ae02c36779ba5f9335 --- /dev/null +++ b/torch_npu/csrc/ipc/NPUIPCTypes.cpp @@ -0,0 +1,252 @@ +#include +#include +#include +#include +#include +#include +#include "torch_npu/csrc/core/npu/NPUGuard.h" +#include "torch_npu/csrc/ipc/NPUIPCTypes.h" + +#include "third_party/acl/inc/acl/acl_base.h" +#include "third_party/acl/inc/acl/acl_rt.h" + +namespace torch_npu { +namespace ipc { + +namespace { + +void warnProducerTerminatedBeforeSharedTensorsReleased() +{ + static bool warned = false; + if (!warned) { + LOG(WARNING) + << "Producer process has been terminated before all shared NPU tensors released. 
See Note [Sharing NPU tensors]"; + warned = true; + } +} + +struct NpuIPCGlobalEntities { + // This class is used as a singleton (see npu_ipc_global_entities) + // This variable is used to track its lifetime to avoid accessing it + // after it was destroyed which would lead to segmentation faults + // Note that a trvial type is used which doesn't suffer from construction + // and destruction order issues + static bool alive; + + std::mutex ref_counters_mutex_; + std::atomic sync_events_used_{0}; + std::map> + ref_counters_files_; + std::shared_ptr next_available_ref_counters_file_; + NpuIPCSentDataLimbo NpuIPCSentDataLimbo_; + + NpuIPCGlobalEntities() + { + alive = true; + } + + ~NpuIPCGlobalEntities() + { + NpuIPCSentDataLimbo_.collect(); + safe_clean_current_file(); + if (next_available_ref_counters_file_) { + warnProducerTerminatedBeforeSharedTensorsReleased(); + } + alive = false; + } + + void safe_clean_current_file() + { + std::lock_guard lock(ref_counters_mutex_); + if (next_available_ref_counters_file_ && + next_available_ref_counters_file_->offsets_in_use() == 0) { + ref_counters_files_.erase(next_available_ref_counters_file_->handle()); + next_available_ref_counters_file_.reset(); + } + } +}; + +bool NpuIPCGlobalEntities::alive = false; +NpuIPCGlobalEntities npu_ipc_global_entities; + +NpuIPCSentDataLimbo::~NpuIPCSentDataLimbo() +{ + collect(); + if (size() > 0) { + warnProducerTerminatedBeforeSharedTensorsReleased(); + } +} + +bool NpuIPCSentDataLimbo::collect() +{ + bool freed_memory = false; + std::vector> reset_blocks; + { + // Begin critical section to modify shared blocks + std::lock_guard lock(limbo_mutex_); + std::vector> kept_blocks; + for (auto& sd : shared_blocks_) { + if (sd->counter_value() > 0) { + kept_blocks.push_back(std::move(sd)); + } else { + freed_memory = true; + reset_blocks.push_back(std::move(sd)); + } + } + shared_blocks_ = std::move(kept_blocks); + } + // Need to reset blocks out of the critical section here, otherwise it + // deadlocks. + for (auto& sd : reset_blocks) { + sd.reset(); + } + return freed_memory; +} + +void NpuIPCSentDataLimbo::add(std::unique_ptr shared_block) +{ + std::lock_guard lock(limbo_mutex_); + static bool warned = false; + if (shared_blocks_.size() > NPU_IPC_WARN_AFTER_X_BLOCKS_IN_LIMBO && + !warned) { + LOG(WARNING) + << "Producer process tried to deallocate over " + << NPU_IPC_WARN_AFTER_X_BLOCKS_IN_LIMBO + << " memory blocks referred by consumer processes. Deallocation might be significantly slowed down. 
" + << "We assume it will never going to be the case."; + warned = true; + } + shared_blocks_.push_back(std::move(shared_block)); +} + +uint64_t NpuIPCSentDataLimbo::size() +{ + std::lock_guard lock(limbo_mutex_); + return shared_blocks_.size(); +} + +void NpuIPCSentDataDelete(void* ptr) +{ + std::unique_ptr sent_data( + static_cast(ptr)); + if (!NpuIPCGlobalEntities::alive) { + return; + } + if (sent_data->counter_value() > 0) { + npu_ipc_global_entities.NpuIPCSentDataLimbo_.add(std::move(sent_data)); + } + npu_ipc_global_entities.NpuIPCSentDataLimbo_.collect(); +} + +void ReturnRefCounter(const std::string& handle, uint64_t offset /* unused */) +{ + if (!NpuIPCGlobalEntities::alive) { + return; + } + std::lock_guard lock( + npu_ipc_global_entities.ref_counters_mutex_); + auto& map = npu_ipc_global_entities.ref_counters_files_; + auto it = map.find(handle); + if (it != map.end()) { + it->second->return_offset(offset); + if (it->second->offsets_in_use() == 0 && !it->second->have_offsets()) { + map.erase(handle); + } + } +} + +} // namespace + +NpuIPCSentData::NpuIPCSentData( + std::string handle, + uint64_t offset, + uint64_t* counter_ptr, + at::Device device) + : handle_(std::move(handle)), + offset_(offset), + counter_ptr_(counter_ptr), + device_(device) +{ + if (npu_ipc_global_entities.sync_events_used_.load() < + NPU_IPC_MAXIMUM_EVENTS_TO_USE) { + } else { + auto stream = c10_npu::getCurrentNPUStream(device.index()); + c10_npu::stream_synchronize(stream); + event_ = nullptr; + event_sync_required_ = false; + } +} + +NpuIPCSentData::~NpuIPCSentData() +{ + ReturnRefCounter(handle_, offset_); + try { + if (event_sync_required_) { + } + } catch (...) { /* No throw */ + } +} + +uint64_t NpuIPCSentData::counter_value() +{ + return *counter_ptr_; +} + +at::DataPtr GetNewRefCountedSentData(void* data, at::Device device) +{ + { + std::lock_guard lock( + npu_ipc_global_entities.ref_counters_mutex_); + if (!npu_ipc_global_entities.next_available_ref_counters_file_) { + std::string ref_counter_handle = at::NewProcessWideShmHandle(); + + int flags = + at::ALLOCATOR_MAPPED_SHAREDMEM | at::ALLOCATOR_MAPPED_EXCLUSIVE; + at::DataPtr sptr = at::RefcountedMapAllocator::makeDataPtr( + ref_counter_handle.c_str(), + flags, + sizeof(int64_t) * NPU_IPC_REF_COUNTER_FILE_SIZE, + nullptr); + auto rc = std::make_shared( + ref_counter_handle, NPU_IPC_REF_COUNTER_FILE_SIZE, std::move(sptr)); + npu_ipc_global_entities.ref_counters_files_[ref_counter_handle] = rc; + npu_ipc_global_entities.next_available_ref_counters_file_ = rc; + } + } + npu_ipc_global_entities.next_available_ref_counters_file_->set_counter(1); + auto sent_data = new NpuIPCSentData( + npu_ipc_global_entities.next_available_ref_counters_file_->handle(), + npu_ipc_global_entities.next_available_ref_counters_file_->get_offset(), + npu_ipc_global_entities.next_available_ref_counters_file_->counter_ptr(), + device); + + npu_ipc_global_entities.next_available_ref_counters_file_->rotate_offset(); + if (!npu_ipc_global_entities.next_available_ref_counters_file_ + ->have_offsets()) { + npu_ipc_global_entities.next_available_ref_counters_file_.reset(); + } + return at::DataPtr(data, sent_data, NpuIPCSentDataDelete, device); +} + +bool NpuIPCCollect() +{ + if (!NpuIPCGlobalEntities::alive) { + return true; + } + bool freed_memory = npu_ipc_global_entities.NpuIPCSentDataLimbo_.collect(); + if (npu_ipc_global_entities.NpuIPCSentDataLimbo_.size() == 0) { + npu_ipc_global_entities.safe_clean_current_file(); + } + return freed_memory; +} + +} // namespace ipc +} // 
namespace torch_npu + +namespace c10_npu { +namespace NPUCachingAllocator { + +REGISTER_FREE_MEMORY_CALLBACK("npu_ipc_collect", NpuIPCCollectCallback); + +} // namespace NPUCachingAllocator +} // namespace c10_npu \ No newline at end of file diff --git a/torch_npu/csrc/ipc/NPUIPCTypes.h b/torch_npu/csrc/ipc/NPUIPCTypes.h new file mode 100644 index 0000000000000000000000000000000000000000..5156af2da429aae306f886e7366cd46a82376667 --- /dev/null +++ b/torch_npu/csrc/ipc/NPUIPCTypes.h @@ -0,0 +1,150 @@ +#pragma once +#include + +#include "torch_npu/csrc/core/npu/NPUMacros.h" +#include "torch_npu/csrc/core/npu/NPUFunctions.h" +#include "torch_npu/csrc/core/npu/NPUStream.h" +#include "torch_npu/csrc/core/npu/NPUCachingAllocator.h" + +namespace torch_npu { +namespace ipc { + +TORCH_NPU_API bool NpuIPCCollect(); + +struct NpuIPCReceivedData final { + NpuIPCReceivedData() = default; + explicit NpuIPCReceivedData(std::shared_ptr shared_ptr) + : shared_ptr_(std::move(shared_ptr)) {} + std::shared_ptr shared_ptr_; +}; + +struct NpuIPCSentData final { + std::string handle_; + uint64_t offset_; + uint64_t* counter_ptr_; // Reference counter shared memory block + at::DataPtr original_ptr_; // Original mem allocation + char* event_; // Sync event + bool event_sync_required_; + at::Device device_; + + NpuIPCSentData( + std::string handle, + uint64_t offset, + uint64_t* counter_ptr, + at::Device device); + ~NpuIPCSentData(); + + uint64_t counter_value(); + std::string handle() + { + return handle_; + } + uint64_t offset() + { + return offset_; + } + void set_original_ptr(at::DataPtr data_ptr) + { + original_ptr_ = std::move(data_ptr); + } +}; + +TORCH_NPU_API at::DataPtr GetNewRefCountedSentData( + void* data, + at::Device device); + +namespace { + +inline constexpr int64_t NPU_IPC_REF_COUNTER_FILE_SIZE = 10000; +inline constexpr int64_t NPU_IPC_WARN_AFTER_X_BLOCKS_IN_LIMBO = 1000; +inline constexpr int64_t NPU_IPC_MAXIMUM_EVENTS_TO_USE = 0; + +// All to be deleted data blocks with non zero reference counter goes there +struct NpuIPCSentDataLimbo final { + ~NpuIPCSentDataLimbo(); + bool collect(); + void add(std::unique_ptr shared_block); + uint64_t size(); + +private: + std::vector> shared_blocks_; + std::mutex limbo_mutex_; +}; + +struct NpuIPCRefCountersFile final { + NpuIPCRefCountersFile( + std::string handle, + uint64_t size, + at::DataPtr data_ptr) + : size_(size), + handle_(std::move(handle)), + refcounted_shared_mem_(std::move(data_ptr)) {} + + uint64_t* counter_ptr() + { + return static_cast(refcounted_shared_mem_.get()) + next_offset_; + } + + void set_counter(uint64_t value) + { + *counter_ptr() = value; + } + + bool have_offsets() + { + return next_offset_ < size_; + } + + bool offsets_in_use() + { + return used_slots_; + } + + uint64_t get_offset() + { + return next_offset_; + } + + void rotate_offset() + { + next_offset_++; + used_slots_++; + } + + void return_offset(uint64_t offset /* unused */) + { + used_slots_--; + } + + std::string handle() + { + return handle_; + } + +private: + uint64_t next_offset_{0}; + uint64_t size_; + uint64_t used_slots_{0}; + std::string handle_; + at::DataPtr refcounted_shared_mem_; +}; + +} // namespace +} // namespace ipc +} // namespace torch_npu + +namespace c10_npu { +namespace NPUCachingAllocator { +namespace { + +class NpuIPCCollectCallback : public FreeMemoryCallback { +public: + bool Execute() override + { + return torch_npu::ipc::NpuIPCCollect(); + } +}; + +} // namespace +} // namespace NPUCachingAllocator +} // namespace c10_npu \ No newline at end 
of file diff --git a/torch_npu/csrc/ipc/StorageSharing.cpp b/torch_npu/csrc/ipc/StorageSharing.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cd7b9e372a7aca72fb4462fe283cd018cc49b329 --- /dev/null +++ b/torch_npu/csrc/ipc/StorageSharing.cpp @@ -0,0 +1,301 @@ +#ifndef BUILD_LIBTORCH + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "torch_npu/csrc/core/NPUStorageImpl.h" +#include "torch_npu/csrc/core/NPUBridge.h" +#include "torch_npu/csrc/core/npu/NPUCachingAllocator.h" +#include "torch_npu/csrc/core/npu/NPUGuard.h" + +#include "torch_npu/csrc/ipc/NPUIPCTypes.h" +#include "torch_npu/csrc/ipc/StorageSharing.h" + +#include "third_party/acl/inc/acl/acl_base.h" +#include "third_party/acl/inc/acl/acl_rt.h" + +namespace torch_npu { +namespace reductions { + +static PyObject* THNPStorage_shareNpu(PyObject* self, PyObject* args) +{ + HANDLE_TH_ERRORS + const auto& storage = THPStorage_Unpack(args); + TORCH_CHECK( + storage.device_type() == at::DeviceType::PrivateUse1, + "_share_npu_: only available on NPU.", PTA_ERROR(ErrCode::PARAM)); + c10::StorageImpl* storage_impl = storage.unsafeGetStorageImpl(); + + if (storage_impl->received_cuda()) { + AT_ERROR( + "Supported to send NPU tensor received from another process; other is not currently supported. Consider cloning before sending."); + } + + at::DeviceGuard device_guard(storage.device()); + THPObjectPtr tuple(PyTuple_New(8)); + THPObjectPtr device(THPUtils_packInt32(storage.device().index())); + THPObjectPtr _handle(Py_None); + Py_INCREF(Py_None); + THPObjectPtr size_bytes(THPUtils_packUInt64(storage.nbytes())); + THPObjectPtr _offset_bytes(THPUtils_packInt32(0)); + THPObjectPtr _ref_counter(Py_None); + Py_INCREF(Py_None); + THPObjectPtr _ref_counter_offset(THPUtils_packInt32(0)); + THPObjectPtr _event_handle(Py_None); + Py_INCREF(Py_None); + THPObjectPtr _event_sync_required(Py_None); + Py_INCREF(Py_None); + if (storage.data()) { + auto shandle = c10_npu::NPUCachingAllocator::shareIpcHandle(storage.mutable_data()); + _handle = PyBytes_FromStringAndSize( + shandle.handle.c_str(), (Py_ssize_t)shandle.handle.size()); + _offset_bytes = PyLong_FromSsize_t((Py_ssize_t)shandle.offset); + + at::DataPtr sent_data_ptr = torch_npu::ipc::GetNewRefCountedSentData( + storage.mutable_data(), storage.device()); + auto old_data_ptr = storage.set_data_ptr(std::move(sent_data_ptr)); + auto sent_data = + static_cast(storage.data_ptr().get_context()); + sent_data->set_original_ptr(std::move(old_data_ptr)); + _ref_counter = PyBytes_FromString((sent_data->handle()).c_str()); + _ref_counter_offset = THPUtils_packUInt64(sent_data->offset()); + + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + aclrtNotify ipc_event_handle; + + if (sent_data->event_sync_required_) { + // TO BE DONE + } + + _event_handle = PyBytes_FromStringAndSize( + (char*)&ipc_event_handle, sizeof(aclrtNotify)); + _event_sync_required = PyBool_FromLong(sent_data->event_sync_required_); + } + + if (!tuple || !device || !_handle || !size_bytes || !_offset_bytes || + !_event_handle) { + return nullptr; + } + PyTuple_SET_ITEM(tuple.get(), 0, device.release()); + PyTuple_SET_ITEM(tuple.get(), 1, _handle.release()); + // Size(in bytes) of the real storage, note this is not the size of basePtr + // memory block. + PyTuple_SET_ITEM(tuple.get(), 2, size_bytes.release()); + // Offset(in bytes) of the real storage in the basePtr memory block. 
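For reference, the 8-tuple assembled here is the same one the Python side unpacks: (device, handle, size_bytes, offset_bytes, ref_counter_handle, ref_counter_offset, event_handle, event_sync_required). A hedged sketch of the round trip, following the names used in reductions.py later in this patch, is shown below; it assumes _share_npu_ and _new_shared_npu end up attached to the storage classes (that registration is not visible in this excerpt), and in real use the rebuild runs in a different process.

    import torch
    import torch_npu

    t = torch.arange(16, device="npu")
    storage = t._typed_storage()

    # Producer side: export the storage; this is the tuple built above.
    (device, handle, size_bytes, offset_bytes,
     ref_counter_handle, ref_counter_offset,
     event_handle, event_sync_required) = storage._share_npu_()

    # Consumer side (normally another process): reattach from the handle.
    rebuilt = type(storage)._new_shared_npu(
        device, handle, size_bytes, offset_bytes,
        ref_counter_handle, ref_counter_offset,
        event_handle, event_sync_required,
    )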
+ // NB: this offset MUST be in bytes instead of numel, since we use + // (storage_handle, offset) + // as key in shared_cache(multiprocessing/reduction.py). + // Offset in numel cannot uniquely represent a storage. + PyTuple_SET_ITEM(tuple.get(), 3, _offset_bytes.release()); + PyTuple_SET_ITEM(tuple.get(), 4, _ref_counter.release()); + PyTuple_SET_ITEM(tuple.get(), 5, _ref_counter_offset.release()); + PyTuple_SET_ITEM(tuple.get(), 6, _event_handle.release()); + PyTuple_SET_ITEM(tuple.get(), 7, _event_sync_required.release()); + return tuple.release(); + END_HANDLE_TH_ERRORS +} + +static PyObject* THNPStorage_releaseIPCCounter(PyObject* _unused, PyObject* args) +{ + HANDLE_TH_ERRORS + TORCH_CHECK(PyTuple_GET_SIZE(args) == 2, "tuple of 2 items expected", PTA_ERROR(ErrCode::PARAM)); + + PyObject* _ref_counter = PyTuple_GET_ITEM(args, 0); + PyObject* _ref_counter_offset = PyTuple_GET_ITEM(args, 1); + if (!(PyBytes_Check(_ref_counter) && THPUtils_checkLong(_ref_counter_offset))) { + THPUtils_invalidArguments( + args, + nullptr, + "_release_ipc_counter in NPU mode", + 1, + "(bytes _ref_counter, int _ref_counter_offset)"); + return nullptr; + } + std::string ref_counter_handle = PyBytes_AS_STRING(_ref_counter); + ptrdiff_t ref_counter_offset = + (ptrdiff_t)THPUtils_unpackLong(_ref_counter_offset); + // We don't want to break existing code, so resource deletion is best + // effort basis. Exception expected if producer process terminated + // before consumer released data. + int flags = at::ALLOCATOR_MAPPED_SHAREDMEM | at::ALLOCATOR_MAPPED_NOCREATE; + try { + auto sptr = at::RefcountedMapAllocator::makeDataPtr( + ref_counter_handle.c_str(), + flags, + sizeof(int64_t) * torch_npu::ipc::NPU_IPC_REF_COUNTER_FILE_SIZE, + nullptr); + *(static_cast(sptr.get()) + ref_counter_offset) -= 1; + } catch (c10::Error& err) { + // Already warned inside of producer process + } + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +static std::string THNPStorage_bytesAsHandleString(PyObject* handle) +{ + HANDLE_TH_ERRORS + char* buffer = nullptr; + Py_ssize_t handle_size = 0; + if (PyBytes_AsStringAndSize(handle, &buffer, &handle_size) == -1) { + TORCH_CHECK(handle_size == ACL_IPC_HANDLE_SIZE, "incorrect handle", PTA_ERROR(ErrCode::PARAM)); + } + return std::string(buffer, handle_size); + END_HANDLE_TH_ERRORS_RET("") +} + +static PyObject* THNPStorage_newSharedNpu(PyObject* _unused, PyObject* args) +{ + HANDLE_TH_ERRORS + TORCH_CHECK(PyTuple_GET_SIZE(args) == 8, "tuple of 8 items expected", PTA_ERROR(ErrCode::PARAM)); + PyObject* _device = PyTuple_GET_ITEM(args, 0); + PyObject* _handle = PyTuple_GET_ITEM(args, 1); + PyObject* _size_bytes = PyTuple_GET_ITEM(args, 2); + PyObject* _offset_bytes = PyTuple_GET_ITEM(args, 3); + PyObject* _ref_counter = PyTuple_GET_ITEM(args, 4); + PyObject* _ref_counter_offset = PyTuple_GET_ITEM(args, 5); + PyObject* _event_handle = PyTuple_GET_ITEM(args, 6); + PyObject* _event_sync_required = PyTuple_GET_ITEM(args, 7); + if (!(THPUtils_checkLong(_device) && THPUtils_checkLong(_size_bytes) && + PyBytes_Check(_handle) && PyBytes_Check(_ref_counter) && + PyBytes_Check(_event_handle) && THPUtils_checkLong(_offset_bytes) && + THPUtils_checkLong(_ref_counter_offset) && + PyBool_Check(_event_sync_required))) { + THPUtils_invalidArguments( + args, + nullptr, + "_new_shared in NPU mode", + 1, + "(int device, bytes handle, int storage_size_bytes, int storage_offset_bytes, bytes _ref_counter, int _ref_counter_offset, bytes event_handle, bool event_sync_required)"); + return nullptr; + } + + size_t 
storage_size = + (size_t)THPUtils_unpackLong(_size_bytes) / sizeof(uint8_t); + ptrdiff_t storage_offset_bytes = + (ptrdiff_t)THPUtils_unpackLong(_offset_bytes); + + const auto device = c10::checked_convert( + THPUtils_unpackLong(_device), "c10::DeviceIndex"); + c10_npu::NPUGuard device_guard(device); + + if (PyObject_IsTrue(_event_sync_required)) { + // TO BE DONE + } + + std::string s_handle = THNPStorage_bytesAsHandleString(_handle); + if (s_handle.empty()) { + return nullptr; + } + std::shared_ptr basePtr = + c10_npu::NPUCachingAllocator::getIpcDevPtr(s_handle); + + // Offset the basePtr to reconstruct the real storage + // devPtr = basePtr + storage_offset + void* devPtr = basePtr.get(); + devPtr = (char*)devPtr + storage_offset_bytes; + + std::string ref_counter_handle = PyBytes_AS_STRING(_ref_counter); + ptrdiff_t ref_counter_offset = + (ptrdiff_t)THPUtils_unpackLong(_ref_counter_offset); + + struct IpcDeleterContext { + std::string ref_counter_handle; + ptrdiff_t ref_counter_offset; + int64_t device; + torch_npu::ipc::NpuIPCReceivedData received_data; + }; + + auto ctx = std::make_unique(); + ctx->ref_counter_handle = std::move(ref_counter_handle); + ctx->ref_counter_offset = ref_counter_offset; + ctx->device = device; + ctx->received_data.shared_ptr_ = std::move(basePtr); + + auto cur_device = c10_npu::current_device(); + c10::DataPtr data_ptr( + devPtr, + ctx.release(), + +[](void* ctx_) { + std::unique_ptr ctx( + static_cast(ctx_)); + + ctx->received_data.shared_ptr_.reset(); + + try { + c10_npu::stream_synchronize( + c10_npu::getCurrentNPUStream(ctx->device)); + } catch (c10::Error& err) { + // Already warned inside of producer process + } + + int flags = + at::ALLOCATOR_MAPPED_SHAREDMEM | at::ALLOCATOR_MAPPED_NOCREATE; + try { + auto sptr = at::RefcountedMapAllocator::makeDataPtr( + ctx->ref_counter_handle.c_str(), + flags, + sizeof(int64_t) * torch_npu::ipc::NPU_IPC_REF_COUNTER_FILE_SIZE, + nullptr); + *(static_cast(sptr.get()) + ctx->ref_counter_offset) -= 1; + } catch (c10::Error& err) { + // Already warned inside of producer process + } + }, + at::Device(at::DeviceType::PrivateUse1, cur_device)); + + c10::intrusive_ptr base = c10::make_intrusive( + c10::StorageImpl::use_byte_size_t(), + storage_size, + std::move(data_ptr), + nullptr, + false); + + base->set_resizable(false); + base->set_received_cuda(true); + + return THPStorage_NewWithStorage( + THPStorageClass, + std::move(base), + c10::impl::PyInterpreterStatus::TAGGED_BY_US); + END_HANDLE_TH_ERRORS +} + +static PyObject* THNPStorage_isShared(PyObject* self, PyObject* arg) +{ + const auto& storage = THPStorage_Unpack(self); + if (storage.device_type() == at::kPrivateUse1) { + Py_RETURN_TRUE; + } + if (at::MapAllocator::fromDataPtr(storage.data_ptr()) || + THManagedMapAllocator::fromDataPtr(storage.data_ptr())) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } +} + +static struct PyMethodDef TorchReductionsMethods[] = { + {"_share_npu_", THNPStorage_shareNpu, METH_O, nullptr}, + {"_release_ipc_counter_npu", THNPStorage_releaseIPCCounter, METH_VARARGS, nullptr}, + {"_new_shared_npu", THNPStorage_newSharedNpu, METH_VARARGS, nullptr}, + {"_is_shared", THNPStorage_isShared, METH_O, nullptr}, + {nullptr, nullptr, 0, nullptr}, +}; + +PyMethodDef* reductions_functions() +{ + return TorchReductionsMethods; +} + +} // namespace reductions +} // namespace torch_npu + +#endif \ No newline at end of file diff --git a/torch_npu/csrc/ipc/StorageSharing.h b/torch_npu/csrc/ipc/StorageSharing.h new file mode 100644 index 
0000000000000000000000000000000000000000..a38e0c0ad68248ecf542a65e5d3f5bc14cff5903 --- /dev/null +++ b/torch_npu/csrc/ipc/StorageSharing.h @@ -0,0 +1,15 @@ +#ifndef BUILD_LIBTORCH +#pragma once + +#include +#include "torch_npu/csrc/core/npu/NPUMacros.h" + +namespace torch_npu { +namespace reductions { + +TORCH_NPU_API PyMethodDef* reductions_functions(); + +} // namespace reductions +} // namespace torch_npu + +#endif \ No newline at end of file diff --git a/torch_npu/csrc/npu/Module.cpp b/torch_npu/csrc/npu/Module.cpp index ecaff129d60b9ac3b443478d17a386e05ede35ef..09e158364bc43ab3d015d99403dc76f7ead5bef0 100644 --- a/torch_npu/csrc/npu/Module.cpp +++ b/torch_npu/csrc/npu/Module.cpp @@ -27,6 +27,8 @@ #include "torch_npu/csrc/core/npu/NPUStream.h" #include "torch_npu/csrc/core/npu/NPUQueue.h" #include "torch_npu/csrc/core/npu/NPUAffinityController.h" +#include "torch_npu/csrc/core/npu/NPUPeerToPeerAccess.h" +#include "torch_npu/csrc/core/npu/NPUIPCPidManager.h" #include "torch_npu/csrc/core/npu/NPUGuard.h" #include "torch_npu/csrc/core/npu/NpuVariables.h" #include "torch_npu/csrc/core/npu/sys_ctrl/npu_sys_ctrl.h" @@ -275,6 +277,24 @@ void RegisterNpuPluggableAllocator(PyObject* module) std::function func = reinterpret_cast(func_ptr); self.set_erase_stream_fn(func); + }) + .def( + "set_get_device_stats_fn", + [](torch::npu::NPUPluggableAllocator::NPUPluggableAllocator& self, + uint64_t func_ptr) { + using FuncType=c10_npu::NPUCachingAllocator::DeviceStats(int); + std::function func = + reinterpret_cast(func_ptr); + self.set_get_device_stats_fn(func); + }) + .def( + "set_reset_peak_status_fn", + [](torch::npu::NPUPluggableAllocator::NPUPluggableAllocator& self, + uint64_t func_ptr) { + using FuncType = void(int); + std::function func = + reinterpret_cast(func_ptr); + self.set_reset_peak_status_fn(func); }); m.def( @@ -1643,6 +1663,34 @@ static PyObject* THNPModule_is_gte_cann_version(PyObject* self, PyObject *args) END_HANDLE_TH_ERRORS } +static PyObject* THNPModule_add_ipc_pid(PyObject* self, PyObject *args) +{ + HANDLE_TH_ERRORS + int pid; + if (!PyArg_ParseTuple(args, "i", &pid)) { + throw torch::TypeError("Pybind failed to parse parameters." + PTA_ERROR(ErrCode::TYPE)); + } + torch_npu::ipc::addPid(pid); + + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +static PyObject* THNPModule_add_p2p_access(PyObject* self, PyObject *args) +{ + HANDLE_TH_ERRORS + int src_dev; + int dst_dev; + if (!PyArg_ParseTuple(args, "ii", &src_dev, &dst_dev)) { + throw torch::TypeError("Pybind failed to parse parameters." 
+ PTA_ERROR(ErrCode::TYPE)); + } + bool warning_flag = false; + at_npu::native::NpuP2pCtrl::get_instance().get_p2p_access(src_dev, dst_dev, warning_flag); + + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + static struct PyMethodDef THNPModule_methods[] = { {"_npu_init", (PyCFunction)THNPModule_initExtension, METH_NOARGS, nullptr}, {"_npu_set_run_yet_variable_to_false", (PyCFunction)THNPModule_set_run_yet_variable_to_false_wrap, METH_NOARGS, nullptr}, @@ -1704,6 +1752,8 @@ static struct PyMethodDef THNPModule_methods[] = { {"_npu_clear_fft_plan_cache", (PyCFunction)THNPModule_npu_clear_fft_plan_cache, METH_NOARGS, nullptr}, {"_get_cann_version", (PyCFunction)THNPModule_get_cann_version, METH_O, nullptr}, {"_is_gte_cann_version", (PyCFunction)THNPModule_is_gte_cann_version, METH_VARARGS, nullptr}, + {"_add_ipc_pid", (PyCFunction)THNPModule_add_ipc_pid, METH_VARARGS, nullptr}, + {"_add_p2p_access", (PyCFunction)THNPModule_add_p2p_access, METH_VARARGS, nullptr}, {nullptr}}; TORCH_NPU_API PyMethodDef* THNPModule_get_methods() diff --git a/torch_npu/csrc/npu/NPUPluggableAllocator.cpp b/torch_npu/csrc/npu/NPUPluggableAllocator.cpp index e8e0fd3eeffaebecc6d11e73de73e49f13af7668..660c69a89d2a4a9326bebdc4dd1c317468da128d 100644 --- a/torch_npu/csrc/npu/NPUPluggableAllocator.cpp +++ b/torch_npu/csrc/npu/NPUPluggableAllocator.cpp @@ -74,6 +74,18 @@ void NPUPluggableAllocator::set_erase_stream_fn( erase_stream_fn_ = std::move(erase_stream_fn); } +void NPUPluggableAllocator::set_get_device_stats_fn( + std::function get_device_stats_fn) +{ + get_device_stats_fn_ = std::move(get_device_stats_fn); +} + +void NPUPluggableAllocator::set_reset_peak_status_fn( + std::function reset_peak_status_fn) +{ + reset_peak_status_fn_ = std::move(reset_peak_status_fn); +} + void* NPUPluggableAllocator::malloc( size_t size, int device, @@ -212,8 +224,11 @@ void NPUPluggableAllocator::eraseStream( c10_npu::NPUCachingAllocator::DeviceStats NPUPluggableAllocator::getDeviceStats(int device) { - TORCH_NPU_WARN("NPUPluggableAllocator does not yet support getDeviceStats. " - "If you need it, please file an issue describing your use case."); + if (get_device_stats_fn_) { + return get_device_stats_fn_(device); + } else { + TORCH_CHECK(false, "get_device_stats_fn_ is not define, please set by set_get_device_stats_fn"); + } } void NPUPluggableAllocator::resetAccumulatedStats(int device) @@ -224,8 +239,11 @@ void NPUPluggableAllocator::resetAccumulatedStats(int device) void NPUPluggableAllocator::resetPeakStats(int device) { - TORCH_NPU_WARN("NPUPluggableAllocator does not yet support resetPeakStats. " - "If you need it, please file an issue describing your use case."); + if (reset_peak_status_fn_) { + reset_peak_status_fn_(device); + } else { + TORCH_CHECK(false, "reset_peak_status_fn_ is not define, please set by set_reset_peak_status_fn"); + } } c10_npu::NPUCachingAllocator::SnapshotInfo NPUPluggableAllocator::snapshot() @@ -282,6 +300,24 @@ void NPUPluggableAllocator::copy_data(void* dest, const void* src, std::size_t c { default_copy_data(dest, src, count); } + +std::shared_ptr NPUPluggableAllocator::getIpcDevPtr(std::string handle) +{ + TORCH_NPU_WARN( + "NPUPluggableAllocator does not yet support getIpcDevPtr. " + "If you need it, please file an issue describing your use case."); + auto sp = std::shared_ptr(); + return sp; +} + +c10_npu::NPUCachingAllocator::ShareableHandle NPUPluggableAllocator::shareIpcHandle(void* ptr) +{ + TORCH_NPU_WARN( + "NPUPluggableAllocator does not yet support shareIPcHandle. 
" + "If you need it, please file an issue describing your use case."); + return c10_npu::NPUCachingAllocator::ShareableHandle{0, nullptr}; +} + void NPUPluggableAllocator::recordHistory( bool enabled, c10_npu::NPUCachingAllocator::CreateContextFn context_recorder, diff --git a/torch_npu/csrc/npu/NPUPluggableAllocator.h b/torch_npu/csrc/npu/NPUPluggableAllocator.h index 3a71319f3c7c4f79bd208206f1543947e64b9b1e..a3691d48eefbaf3743f5ce29a304a0dab3560151 100644 --- a/torch_npu/csrc/npu/NPUPluggableAllocator.h +++ b/torch_npu/csrc/npu/NPUPluggableAllocator.h @@ -45,6 +45,8 @@ struct NPUPluggableAllocator std::function record_stream_fn); void set_erase_stream_fn( std::function erase_stream_fn); + void set_get_device_stats_fn(std::function get_device_stats_fn); + void set_reset_peak_status_fn(std::function reset_peak_status_fn); void* malloc(size_t size, int device, aclrtStream stream); c10::DataPtr allocate(size_t size) override; @@ -81,6 +83,8 @@ struct NPUPluggableAllocator void FreeDeviceCachedMemory(int device) override; std::string name() override; void copy_data(void* dest, const void* src, std::size_t count) const final; + std::shared_ptr getIpcDevPtr(std::string handle) override; + c10_npu::NPUCachingAllocator::ShareableHandle shareIpcHandle(void*) override; void recordHistory( bool enabled, c10_npu::NPUCachingAllocator::CreateContextFn context_recorder, @@ -108,6 +112,8 @@ protected: std::function base_alloc_fn_; std::function record_stream_fn_; std::function erase_stream_fn_; + std::function get_device_stats_fn_; + std::function reset_peak_status_fn_; std::mutex allocator_mutex_; // We do the bookeeping here in order to simplify custom allocators std::unordered_map allocation_metadata_; diff --git a/torch_npu/csrc/npu/memory_snapshot.cpp b/torch_npu/csrc/npu/memory_snapshot.cpp index 47fbf4de6cf5916a4713f9dde961e80fc89c8f74..cc893243a76fc8dd05d60b13e78fe429d7435dcf 100644 --- a/torch_npu/csrc/npu/memory_snapshot.cpp +++ b/torch_npu/csrc/npu/memory_snapshot.cpp @@ -16,7 +16,11 @@ namespace torch_npu { std::shared_ptr gather() { +#if defined(__x86_64__) return torch::CapturedTraceback::gather(true, true, false); +#else + return torch_npu::CapturedTraceback::gather(true, true, false); +#endif } std::shared_ptr gather_with_cpp() diff --git a/torch_npu/multiprocessing/reductions.py b/torch_npu/multiprocessing/reductions.py new file mode 100644 index 0000000000000000000000000000000000000000..cc40949f7933337eaf6a441b688d1e941849ffa2 --- /dev/null +++ b/torch_npu/multiprocessing/reductions.py @@ -0,0 +1,178 @@ +__all__ = ["rebuild_npu_tensor"] + +import multiprocessing +import torch +from torch.multiprocessing.reductions import ( + shared_cache, + rebuild_storage_filename, + rebuild_storage_empty, + rebuild_storage_fd, + StorageWeakRef, + fd_id, + rebuild_tensor, + storage_from_cache, +) + +import torch_npu + + +def rebuild_npu_tensor( + tensor_cls, + tensor_size, + tensor_stride, + tensor_offset, + storage_cls, + dtype, + storage_device, + storage_handle, + storage_size_bytes, + storage_offset_bytes, + requires_grad, + ref_counter_handle, + ref_counter_offset, + event_handle, + event_sync_required, +): + # If storage_handle is None, storage points to nullptr. 
diff --git a/torch_npu/multiprocessing/reductions.py b/torch_npu/multiprocessing/reductions.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc40949f7933337eaf6a441b688d1e941849ffa2
--- /dev/null
+++ b/torch_npu/multiprocessing/reductions.py
@@ -0,0 +1,178 @@
+__all__ = ["rebuild_npu_tensor"]
+
+import multiprocessing
+import torch
+from torch.multiprocessing.reductions import (
+    shared_cache,
+    rebuild_storage_filename,
+    rebuild_storage_empty,
+    rebuild_storage_fd,
+    StorageWeakRef,
+    fd_id,
+    rebuild_tensor,
+    storage_from_cache,
+)
+
+import torch_npu
+
+
+def rebuild_npu_tensor(
+    tensor_cls,
+    tensor_size,
+    tensor_stride,
+    tensor_offset,
+    storage_cls,
+    dtype,
+    storage_device,
+    storage_handle,
+    storage_size_bytes,
+    storage_offset_bytes,
+    requires_grad,
+    ref_counter_handle,
+    ref_counter_offset,
+    event_handle,
+    event_sync_required,
+):
+    # If storage_handle is None, storage points to nullptr.
+    if storage_handle is None or storage_size_bytes == 0:
+        storage = storage_cls(0, dtype=dtype, device=storage_device, _internal=True)
+    else:
+        storage = storage_from_cache(
+            storage_cls, (storage_handle, storage_offset_bytes)
+        )
+        if storage is None:
+            torch_npu.npu._lazy_init()
+            storage = storage_cls._new_shared_npu(
+                storage_device,
+                storage_handle,
+                storage_size_bytes,
+                storage_offset_bytes,
+                ref_counter_handle,
+                ref_counter_offset,
+                event_handle,
+                event_sync_required,
+            )
+            shared_cache[(storage_handle, storage_offset_bytes)] = StorageWeakRef(
+                storage
+            )
+        else:
+            # We already ref counting this Storage, but producer needs new ref-counters to be released.
+            storage_cls._release_ipc_counter_npu(
+                ref_counter_handle, ref_counter_offset, device=storage_device
+            )
+
+    _storage = (
+        storage
+        if isinstance(storage, torch.UntypedStorage)
+        else storage._untyped_storage
+    )
+
+    t = torch._utils._rebuild_tensor(
+        torch.storage.TypedStorage(wrap_storage=_storage, dtype=dtype, _internal=True),
+        tensor_offset,
+        tensor_size,
+        tensor_stride,
+    )
+
+    if tensor_cls == torch.nn.parameter.Parameter:
+        # It is crucial for integer tensors to receive
+        # the requires_grad=False as an argument in the constructor
+        t = torch.nn.parameter.Parameter(t, requires_grad=requires_grad)
+    else:
+        t.requires_grad = requires_grad
+
+    return t
+
+
+def _npu_reduce_tensor(tensor):
+    storage = tensor._typed_storage()
+
+    if tensor.requires_grad and not tensor.is_leaf:
+        raise RuntimeError(
+            "Cowardly refusing to serialize non-leaf tensor which requires_grad, "
+            "since autograd does not support crossing process boundaries. "
+            "If you just want to transfer the data, call detach() on the tensor "
+            "before serializing (e.g., putting it on the queue)."
+        )
+
+    torch._namedtensor_internals.check_serializing_named_tensor(tensor)
+    torch.utils.hooks.warn_if_has_hooks(tensor)
+
+    if storage._untyped_storage.device.type == "npu":
+        (
+            device,
+            handle,
+            storage_size_bytes,
+            storage_offset_bytes,
+            ref_counter_handle,
+            ref_counter_offset,
+            event_handle,
+            event_sync_required,
+        ) = storage._share_npu_()
+        tensor_offset = tensor.storage_offset()
+        shared_cache[handle] = StorageWeakRef(storage)
+        return (
+            rebuild_npu_tensor,
+            (
+                type(tensor),
+                tensor.size(),
+                tensor.stride(),
+                tensor_offset,  # tensor offset in its storage
+                type(storage),
+                tensor.dtype,
+                device,
+                handle,  # identifier which NPU allocation is the storage in.
+                storage_size_bytes,  # size(in bytes) of the storage
+                storage_offset_bytes,  # offset(in bytes) of the storage in the NPU allocation
+                tensor.requires_grad,
+                ref_counter_handle,
+                ref_counter_offset,
+                event_handle,
+                event_sync_required,
+            ),
+        )
+
+    # _backward_hooks purposely omitted here, see Note [Don't serialize hooks]
+    metadata = (
+        tensor.storage_offset(),
+        tensor.size(),
+        tensor.stride(),
+        tensor.requires_grad,
+    )
+    return (rebuild_tensor, (type(tensor), storage, metadata))
+
+
+def _npu_reduce_storage(storage):
+    from torch.multiprocessing import get_sharing_strategy
+
+    if storage.is_npu:
+        raise RuntimeError(
+            "Cannot pickle NPU storage; try pickling a NPU tensor instead"
+        )
+    elif get_sharing_strategy() == "file_system":
+        metadata = storage._share_filename_cpu_()
+        cache_key = metadata[1]
+        rebuild = rebuild_storage_filename
+        if isinstance(storage, torch.TypedStorage):
+            metadata += (storage.dtype,)
+        storage._shared_incref()
+    elif storage.size() == 0:
+        # This is special cased because Empty tensors
+        # (with size 0) cannot be mmapped.
+        return (rebuild_storage_empty, (type(storage),))
+    else:
+        fd, size = storage._share_fd_cpu_()
+        df = multiprocessing.reduction.DupFd(fd)
+        cache_key = fd_id(fd)
+        metadata = (df, size)
+        rebuild = rebuild_storage_fd  # type: ignore[assignment]
+
+    shared_cache[cache_key] = StorageWeakRef(storage)
+    return (rebuild, (type(storage),) + metadata)
+
+
+def _add_reductions_methods():
+    torch.multiprocessing.reductions.reduce_tensor = _npu_reduce_tensor
+    torch.multiprocessing.reductions.reduce_storage = _npu_reduce_storage
+
+    torch.multiprocessing.reductions.init_reductions()
\ No newline at end of file
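The reduction hooks above mirror the CUDA path: the producer pickles an NPU tensor through _npu_reduce_tensor, which exports the storage with _share_npu_(), and the consumer rebuilds it through rebuild_npu_tensor. A minimal sketch follows, assuming _add_reductions_methods() is wired into torch_npu initialisation so that importing torch_npu installs the hooks (that wiring is not shown in this hunk):

    import torch
    import torch.multiprocessing as mp

    import torch_npu  # assumed to install the NPU reduction hooks on import


    def consumer(q):
        t = q.get()              # rebuilt on this side via rebuild_npu_tensor
        print(t.device, t.sum())


    if __name__ == "__main__":
        mp.set_start_method("spawn", force=True)
        q = mp.Queue()
        x = torch.ones(4, device="npu:0")
        p = mp.Process(target=consumer, args=(q,))
        p.start()
        q.put(x)                 # pickled via _npu_reduce_tensor -> NPU IPC handle
        p.join()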
diff --git a/torch_npu/npu/_format.py b/torch_npu/npu/_format.py
new file mode 100644
index 0000000000000000000000000000000000000000..209078b1b2dd7ae87d70c85974c2a849398730dd
--- /dev/null
+++ b/torch_npu/npu/_format.py
@@ -0,0 +1,39 @@
+from enum import IntEnum
+
+import torch
+import torch_npu
+
+
+class Format(IntEnum):
+    """NPU storage format enumeration class"""
+    UNDEFINED = -1
+    NCHW = 0
+    NHWC = 1
+    ND = 2
+    NC1HWC0 = 3
+    FRACTAL_Z = 4
+    NC1HWC0_C04 = 12
+    HWCN = 16
+    NDHWC = 27
+    FRACTAL_NZ = 29
+    NCDHW = 30
+    NDC1HWC0 = 32
+    FRACTAL_Z_3D = 33
+    NC = 35
+    NCL = 47
+
+    def __str__(self):
+        return self.name
+
+
+def _apply_npu_format_patch():
+    orig_get_format = torch_npu.get_npu_format
+
+    def patched_get_format(tensor):
+        """get the Format type of tensor"""
+        format_int = orig_get_format(tensor)
+        return Format(format_int)
+
+    torch_npu.get_npu_format = patched_get_format
+    torch_npu.Format = Format
+    
\ No newline at end of file
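After _apply_npu_format_patch() runs (assumed to happen during torch_npu initialisation), torch_npu.get_npu_format returns a Format member instead of a bare integer while remaining interchangeable with the old codes, since Format is an IntEnum:

    import torch
    import torch_npu

    x = torch.ones(2, 3, device="npu")
    fmt = torch_npu.get_npu_format(x)   # a Format member, e.g. Format.ND
    print(fmt, int(fmt))                # e.g. "ND 2"
    assert fmt == int(fmt)              # IntEnum members compare equal to their integer value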
diff --git a/torch_npu/utils/storage.py b/torch_npu/utils/storage.py
index 9304f141bf2fc5475f1594d6aff44aee99fbc289..85a2a402a37c11f81e77155ab9c792c56c80a061 100644
--- a/torch_npu/utils/storage.py
+++ b/torch_npu/utils/storage.py
@@ -1,4 +1,7 @@
+__all__ = []
+
 import copy
+from typing import Union
 
 import torch
 from torch.storage import _warn_typed_storage_removal
@@ -49,6 +52,37 @@ def _deepcopy(self, memo):
     return self._new_wrapped_storage(copy.deepcopy(self._untyped_storage, memo))
 
 
+def _share_npu_(self, *args, **kwargs):
+    return torch_npu._C._share_npu_(self, *args, **kwargs)
+
+
+def _typed_storage_share_npu_(self, *args, **kwargs):
+    return self._untyped_storage._share_npu_(*args, **kwargs)
+
+
+def _new_shared_npu(*args, **kwargs):
+    return torch_npu._C._new_shared_npu(*args, **kwargs)
+
+
+def _typed_storage_new_shared_npu(*args, **kwargs):
+    return torch.UntypedStorage._new_shared_npu(*args, **kwargs)
+
+
+def _release_ipc_counter_npu(*args, **kwargs):
+    return torch_npu._C._release_ipc_counter_npu(*args, **kwargs)
+
+
+def _typed_storage_release_ipc_counter_npu(*args, device: Union[str, torch.device] = "npu", **kwargs):
+    return torch.UntypedStorage._release_ipc_counter_npu(*args, **kwargs)
+
+
 def _add_storage_methods():
     torch.storage.UntypedStorage.cpu = _cpu
     torch.storage.TypedStorage._deepcopy = _deepcopy
+
+    setattr(torch.UntypedStorage, "_share_npu_", _share_npu_)
+    setattr(torch.UntypedStorage, "_new_shared_npu", _new_shared_npu)
+    setattr(torch.UntypedStorage, "_release_ipc_counter_npu", _release_ipc_counter_npu)
+    setattr(torch.TypedStorage, "_share_npu_", _typed_storage_share_npu_)
+    setattr(torch.TypedStorage, "_new_shared_npu", _typed_storage_new_shared_npu)
+    setattr(torch.TypedStorage, "_release_ipc_counter_npu", _typed_storage_release_ipc_counter_npu)
\ No newline at end of file
diff --git a/torch_npu/utils/unsupport_api.py b/torch_npu/utils/unsupport_api.py
index 61ba27b3a239f000e7d93437add90e65d62884b8..5626e940b6a690e7a74815095c8d51a3fd08dabd 100644
--- a/torch_npu/utils/unsupport_api.py
+++ b/torch_npu/utils/unsupport_api.py
@@ -6,8 +6,6 @@
 value: parent_module(object)
 """
 unsupported_Tensor_api = {
-    "is_shared": torch.Tensor,
-    "share_memory_": torch.Tensor
 }
 
 unsupported_nn_api = {
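Dropping is_shared and share_memory_ from unsupported_Tensor_api means torch_npu no longer patches them to raise; CPU storages use the stock file-descriptor or file-system sharing, while NPU storages are exported through the new _share_npu_ path when a tensor is pickled. A small sketch of what is expected to work after this change (behaviour for the CPU case is assumed to match stock PyTorch):

    import torch
    import torch_npu

    t = torch.zeros(8)
    t.share_memory_()        # no longer blocked by unsupported_Tensor_api
    print(t.is_shared())     # True: the CPU storage now lives in shared memory

    # NPU tensors are shared lazily: sending one through torch.multiprocessing
    # triggers _npu_reduce_tensor -> storage._share_npu_() instead.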