diff --git a/test/npu/test_expandable_segments.py b/test/npu/test_expandable_segments.py
new file mode 100644
index 0000000000000000000000000000000000000000..7535581550d2687eb935afdf3efadd6ca1c1dd09
--- /dev/null
+++ b/test/npu/test_expandable_segments.py
@@ -0,0 +1,41 @@
+import os
+import gc
+
+import torch
+import torch_npu
+from torch_npu.testing.testcase import TestCase, run_tests
+
+os.environ["PYTORCH_NPU_ALLOC_CONF"] = "expandable_segments:True"
+
+
+class Test_expandable_segments(TestCase):
+    def test_empty_virt_addr_cache(self):
+        gc.collect()
+        torch_npu.npu.empty_cache()
+        prev = 0
+
+        x = torch.empty((7500, 1024, 1024), device="npu")
+        del x
+        last_r = torch_npu.npu.memory_reserved()
+
+        torch_npu.npu.empty_virt_addr_cache()
+        new_r = torch_npu.npu.memory_reserved()
+        self.assertEqual(new_r, prev)
+        self.assertEqual(torch_npu.npu.max_memory_reserved(), last_r)
+
+        # test re-allocation after emptying the virtual address cache
+        y = None
+        try:
+            y = torch.empty((7500, 1024, 1024), device="npu")
+            self.assertGreater(torch_npu.npu.memory_allocated(), prev)
+        finally:
+            if y is not None:
+                del y
+        self.assertEqual(torch_npu.npu.memory_allocated(), prev)
+        torch_npu.npu.empty_virt_addr_cache()
+        # release the cached physical handles with empty_cache()
+        torch_npu.npu.empty_cache()
+        self.assertEqual(torch_npu.npu.memory_reserved(), prev)
+
+if __name__ == '__main__':
+    run_tests()
diff --git a/test/torch_npu_schema.json b/test/torch_npu_schema.json
index ec92aba7b2c0e2e677f56abb2691e2f437d3200e..b4651235816841a6c1313eb8715685e85cf6a01b 100644
--- a/test/torch_npu_schema.json
+++ b/test/torch_npu_schema.json
@@ -1034,6 +1034,9 @@
     "torch_npu.npu.empty_cache": {
         "signature": "()"
     },
+    "torch_npu.npu.empty_virt_addr_cache": {
+        "signature": "()"
+    },
     "torch_npu.npu.enable_deterministic_with_backward": {
         "signature": "(tensor: torch.Tensor)"
     },
@@ -1376,6 +1379,9 @@
     "torch_npu.npu.memory.empty_cache": {
         "signature": "()"
     },
+    "torch_npu.npu.memory.empty_virt_addr_cache": {
+        "signature": "()"
+    },
     "torch_npu.npu.memory.get_allocator_backend": {
         "signature": "() -> str"
     },
diff --git a/torch_npu/csrc/core/npu/NPUCachingAllocator.cpp b/torch_npu/csrc/core/npu/NPUCachingAllocator.cpp
index e42e354a8c003dae43fca375a97ec5a11e8a7db4..d04e485aefa8880169ba1f0a636e63b836b947a1 100644
--- a/torch_npu/csrc/core/npu/NPUCachingAllocator.cpp
+++ b/torch_npu/csrc/core/npu/NPUCachingAllocator.cpp
@@ -188,6 +188,8 @@ struct BlockPool {
     std::set<Block*, Comparison> unmapped;
     const bool is_small;
     PrivatePool *owner_PrivatePool;
+    // physical handles from unmapped segments, cached here for reuse
+    std::vector<aclrtDrvMemHandle> free_physical_handles_;
 
     BlockPool(bool small, PrivatePool *private_pool = nullptr)
         : blocks(BlockComparatorSize),
@@ -404,7 +406,7 @@ struct ExpandableSegment {
     // returns the actual range mapped, which may be
     // greater than requested if size is not aligned to segment_size_.
    // return size of 0 indicates OOM
-    SegmentRange map(SegmentRange range)
+    SegmentRange map(SegmentRange range, BlockPool *pool)
     {
         auto begin = segmentLeft(range.ptr);
         auto end = segmentRight(range.ptr + range.size);
@@ -418,6 +420,13 @@
         for (auto i : c10::irange(begin, end)) {
             TORCH_INTERNAL_ASSERT(!handles_.at(i), PTA_ERROR(ErrCode::VALUE));
             aclrtDrvMemHandle handle = nullptr;
+            if (!pool->free_physical_handles_.empty()) {
+                ASCEND_LOGD("Remap cached physical handles for block %zu", i);
+                handle = pool->free_physical_handles_.back();
+                pool->free_physical_handles_.pop_back();
+                handles_.at(i) = Handle{handle, std::nullopt};
+                continue;
+            }
             aclrtPhysicalMemProp prop = {};
             prop.handleType = ACL_MEM_HANDLE_TYPE_NONE;
             prop.allocationType = ACL_MEM_ALLOCATION_TYPE_PINNED;
@@ -425,6 +434,7 @@
             prop.location.type = ACL_MEM_LOCATION_TYPE_DEVICE;
             prop.location.id = static_cast<uint32_t>(device_);
             prop.reserve = 0;
+            ASCEND_LOGD("Alloc memory from physical device for block %zu", i);
             auto status = c10_npu::acl::AclrtMallocPhysical(&handle, segment_size_, &prop, 0);
             if (status == ACL_ERROR_RT_MEMORY_ALLOCATION) {
                 for (auto j : c10::irange(begin, i)) {
@@ -449,14 +459,14 @@
     // unmaps all the completely empty segment_size_ segments between
     // [begin, begin + size), returns the offset where the range begin,
     // and the actual size unmapped (multiple of segment_size_)
-    SegmentRange unmap(SegmentRange range)
+    SegmentRange unmap(SegmentRange range, BlockPool *pool)
     {
         auto begin = segmentRight(range.ptr);
         auto end = segmentLeft(range.ptr + range.size);
         if (begin >= end) {
             return SegmentRange{ range.ptr, 0 };
         }
-        unmapHandles(begin, end);
+        unmapHandles(begin, end, pool);
         return rangeFromHandles(begin, end);
     }
@@ -563,7 +573,7 @@ private:
         ASCEND_LOGD("NPUCachingAllocator mapAndSetAccess: segment_size=%zu", segment_size_);
     }
 
-    void unmapHandles(size_t begin, size_t end)
+    void unmapHandles(size_t begin, size_t end, BlockPool *pool = nullptr)
     {
         // note: unlike aclrtFree, MemUnmap and MemRelease do
         // not appear to synchronize in all cases, so we have to wait for the
@@ -595,7 +605,11 @@
                 continue;
             }
         }
-        NPU_CHECK_ERROR(c10_npu::acl::AclrtFreePhysical(h.handle));
+        if (!pool) {
+            NPU_CHECK_ERROR(c10_npu::acl::AclrtFreePhysical(h.handle));
+        } else {
+            pool->free_physical_handles_.push_back(h.handle);
+        }
     }
     ASCEND_LOGD("NPUCachingAllocator unmap: segment_size=%zu", segment_size_);
     trimHandles();
@@ -1389,7 +1403,7 @@
             c10_npu::npuSynchronizeDevice(true);
         }
         c10_npu::NPUWorkspaceAllocator::emptyCache(device, true);
-        block_found = (release_cached_blocks(true, context) && alloc_block(params, true, context, lock));
+        block_found = (release_cached_blocks(true, context, true) && alloc_block(params, true, context, lock));
     }
 
     if (!block_found) {
@@ -1738,14 +1752,14 @@
     }
 
     /* * returns cached blocks to the system allocator * */
-    void emptyCache(int device, bool check_error)
+    void emptyCache(int device, bool check_error, bool free_physical)
     {
         std::shared_ptr<GatheredContext> context = maybeGatherContext(RecordContext::ALL);
         // Make sure event deque from taskqueue, then synchronize Event
         c10_npu::npuSynchronizeDevice(check_error);
         std::lock_guard<std::recursive_mutex> lock(mutex);
         c10_npu::NPUWorkspaceAllocator::emptyCache(device, check_error);
-        release_cached_blocks(check_error, context);
+        release_cached_blocks(check_error, context, free_physical);
     }
 
     void buildServerMemMapForHccl(std::shared_ptr hcclComm)
@@ -2319,12 +2333,12 @@ private:
         return candidate;
     }
-    bool map_block(Block *to_map, size_t size, const std::shared_ptr<GatheredContext> &ctx)
+    bool map_block(Block *to_map, size_t size, const std::shared_ptr<GatheredContext> &ctx, BlockPool *map_pool)
     {
         TORCH_INTERNAL_ASSERT(!to_map->mapped && size <= to_map->size, PTA_ERROR(ErrCode::VALUE));
         TORCH_INTERNAL_ASSERT(!to_map->context_when_allocated); // unmapped blocks should not keep
                                                                 // history
-        auto mapped_range = to_map->expandable_segment_->map(SegmentRange{ to_map->ptr, size });
+        auto mapped_range = to_map->expandable_segment_->map(SegmentRange{ to_map->ptr, size }, map_pool);
         // failed to map the memory
         if (mapped_range.size == 0) {
             return false;
         }
@@ -2375,7 +2389,7 @@
         // unmapped -> free -> *
         // free -> unmapped -> *
 
-        if (!candidate->mapped && !map_block(candidate, std::min(candidate->size, size), ctx)) {
+        if (!candidate->mapped && !map_block(candidate, std::min(candidate->size, size), ctx, pool)) {
             return nullptr;
         }
         TORCH_INTERNAL_ASSERT(candidate->mapped, PTA_ERROR(ErrCode::INTERNAL));
@@ -2388,7 +2402,7 @@
         if (C10_UNLIKELY(new_candidate == nullptr)) {
             return nullptr;
         }
-        if (!map_block(new_candidate, std::min(remaining, candidate->next->size), ctx)) {
+        if (!map_block(new_candidate, std::min(remaining, candidate->next->size), ctx, pool)) {
             return nullptr;
         }
         candidate = new_candidate;
@@ -2801,21 +2815,21 @@
     }
 
     // npuSynchronizeDevice must be executed before this function can be called
-    bool release_cached_blocks(bool check_error, const std::shared_ptr<GatheredContext> &context)
+    bool release_cached_blocks(bool check_error, const std::shared_ptr<GatheredContext> &context, bool free_physical)
     {
         // First ensure that all blocks that can't currently be allocated due to
         // outstanding events are returned to the pool.
         synchronize_and_free_events(check_error, context);
 
         // Free all non-split cached blocks
-        release_blocks(large_blocks, context);
-        release_blocks(small_blocks, context);
+        release_blocks(large_blocks, context, free_physical);
+        release_blocks(small_blocks, context, free_physical);
 
         for (auto it = graph_pools_freeable.begin(); it != graph_pools_freeable.end();) {
             // See notifyCaptureDestroy for the strategy here.
             TORCH_INTERNAL_ASSERT(it->second->use_count == 0);
-            release_blocks(it->second->small_blocks, context);
-            release_blocks(it->second->large_blocks, context);
+            release_blocks(it->second->small_blocks, context, free_physical);
+            release_blocks(it->second->large_blocks, context, free_physical);
             if (it->second->npuMalloc_count == 0) {
                 auto erase_count = graph_pools.erase(it->first);
                 TORCH_INTERNAL_ASSERT(erase_count == 1);
@@ -2883,9 +2897,10 @@
         block = nullptr;
     }
 
-    void unmap_block(Block *block, const std::shared_ptr<GatheredContext> &context)
+    void unmap_block(Block *block, const std::shared_ptr<GatheredContext> &context, bool free_physical)
     {
-        auto unmapped = block->expandable_segment_->unmap(SegmentRange{ block->ptr, block->size });
+        auto pool = free_physical ? nullptr : block->pool;
+        auto unmapped = block->expandable_segment_->unmap(SegmentRange{ block->ptr, block->size }, pool);
         if (unmapped.size == 0) {
             return;
         }
@@ -2934,7 +2949,7 @@
                       context ? context : block->context_when_segment_allocated);
     }
 
-    void release_blocks(BlockPool &pool, const std::shared_ptr<GatheredContext> &context)
+    void release_blocks(BlockPool &pool, const std::shared_ptr<GatheredContext> &context, bool free_physical)
     {
         std::vector<Block*> to_unmap;
         // Frees all non-split blocks
@@ -2952,11 +2967,19 @@
             }
         }
         for (Block *block : to_unmap) {
-            unmap_block(block, context);
+            unmap_block(block, context, free_physical);
             if (!block->prev && !block->next) {
                 release_expandable_segment(block);
             }
         }
+        // free cached physical handles
+        if (free_physical) {
+            while (!pool.free_physical_handles_.empty()) {
+                aclrtDrvMemHandle handle = pool.free_physical_handles_.back();
+                NPU_CHECK_ERROR(c10_npu::acl::AclrtFreePhysical(handle));
+                pool.free_physical_handles_.pop_back();
+            }
+        }
     }
 
     EventPool::Event create_event_internal(int idx)
@@ -3315,7 +3338,7 @@ public:
         }
     }
 
-    void emptyCache(bool check_error) override
+    void emptyCacheImpl(bool check_error, bool free_physical) override
     {
         ASCEND_LOGD("Begin empty cache with check_error = %d", check_error);
         int32_t current_device = 0;
@@ -3331,7 +3354,7 @@
         } else {
             NPU_CHECK_WARN(c10_npu::SetDevice(device_idx));
         }
-        device_allocator[device_idx]->emptyCache(device_idx, check_error);
+        device_allocator[device_idx]->emptyCache(device_idx, check_error, free_physical);
     }
     if (check_error) {
         NPU_CHECK_ERROR(c10_npu::MaybeSetDevice(current_device));
@@ -3341,6 +3364,19 @@
         ASCEND_LOGD("End empty cache with check_error = %d", check_error);
     }
 
+    void emptyCache(bool check_error) override
+    {
+        emptyCacheImpl(check_error, true);
+    }
+
+    void emptyVirtAddrCache(bool check_error) override
+    {
+        if (!CachingAllocatorConfig::expandable_segments()) {
+            AT_ERROR("Unsupported config for empty_virt_addr_cache, please enable expandable_segments.");
+        }
+        emptyCacheImpl(check_error, false);
+    }
+
     void clearIpcHandles() override
     {
         std::lock_guard lock(ipcHandleMutex);
@@ -3731,7 +3767,7 @@ public:
 
     void FreeDeviceCachedMemory(int device) override
     {
-        device_allocator[device]->emptyCache(device, true);
+        device_allocator[device]->emptyCache(device, true, true);
     }
 
     std::string name() override
diff --git a/torch_npu/csrc/core/npu/NPUCachingAllocator.h b/torch_npu/csrc/core/npu/NPUCachingAllocator.h
index 13c68aa0e3fccf1f6319bc648000b686ab589165..821b29d1e27d38a6b8b207761a9afb0135223e28 100644
--- a/torch_npu/csrc/core/npu/NPUCachingAllocator.h
+++ b/torch_npu/csrc/core/npu/NPUCachingAllocator.h
@@ -202,7 +202,9 @@ public:
     virtual void init(int device_count) = 0;
     virtual bool initialized() = 0;
    virtual void setMemoryFraction(double fraction, int device) = 0;
+    virtual void emptyCacheImpl(bool check_error, bool free_physical) = 0;
     virtual void emptyCache(bool check_error) = 0;
+    virtual void emptyVirtAddrCache(bool check_error) = 0;
     virtual void clearIpcHandles() = 0;
     virtual void cacheInfo(int dev_id, size_t* cachedAndFree, size_t* largestBlock) = 0;
     virtual void* getBaseAllocation(void* ptr, size_t* size) = 0;
@@ -306,11 +308,21 @@
 inline void setMemoryFraction(double fraction, int device)
 {
     return get()->setMemoryFraction(fraction, device);
 }
 
+inline void emptyCacheImpl(bool check_error = true, bool free_physical = true)
+{
+    return get()->emptyCacheImpl(check_error, free_physical);
+}
+
 C10_NPU_API inline void emptyCache(bool check_error = true)
 {
     return get()->emptyCache(check_error);
 }
 
+C10_NPU_API inline void emptyVirtAddrCache(bool check_error = true)
+{
+    return get()->emptyVirtAddrCache(check_error);
+}
+
 inline void clearIpcHandles()
 {
     return get()->clearIpcHandles();
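With the allocator changes above, emptyVirtAddrCache() releases only the virtual address reservations: the unmapped physical handles are parked in BlockPool::free_physical_handles_ and reused by the next map() call, while a full emptyCache() (free_physical = true) also returns the cached handles to the driver via AclrtFreePhysical. A minimal Python sketch of the intended stat transitions; it mirrors the new unit test, and the tensor shape is illustrative (exact reserved-byte values depend on the allocator configuration):

    import os

    # Must be set before the first NPU allocation so expandable segments are active.
    os.environ["PYTORCH_NPU_ALLOC_CONF"] = "expandable_segments:True"

    import torch
    import torch_npu

    x = torch.empty((1024, 1024, 64), device="npu")
    del x

    # The freed block stays cached, so reserved memory remains non-zero.
    assert torch_npu.npu.memory_reserved() > 0

    # Unmap virtual addresses only; physical handles stay cached per pool.
    torch_npu.npu.empty_virt_addr_cache()
    assert torch_npu.npu.memory_reserved() == 0

    # Re-allocation takes the cached-handle fast path in map() instead of
    # calling aclrtMallocPhysical again.
    y = torch.empty((1024, 1024, 64), device="npu")
    del y

    # A subsequent empty_cache() frees any handles still parked in the pools.
    torch_npu.npu.empty_virt_addr_cache()
    torch_npu.npu.empty_cache()
    assert torch_npu.npu.memory_reserved() == 0
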
diff --git a/torch_npu/csrc/npu/Module.cpp b/torch_npu/csrc/npu/Module.cpp
index c66f2a8d52392bf0c3f3ba05457c77a5bec06f5e..3c533ac9c4e3f8583b02c026527156809bb49407 100644
--- a/torch_npu/csrc/npu/Module.cpp
+++ b/torch_npu/csrc/npu/Module.cpp
@@ -1017,6 +1017,14 @@ PyObject* THNPModule_emptyCache(PyObject *_unused, PyObject *noargs)
     Py_RETURN_NONE;
 }
 
+PyObject* THNPModule_emptyVirtAddrCache(PyObject *_unused, PyObject *noargs)
+{
+    HANDLE_TH_ERRORS
+    c10_npu::NPUCachingAllocator::emptyVirtAddrCache();
+    END_HANDLE_TH_ERRORS
+    Py_RETURN_NONE;
+}
+
 PyObject* THNPModule_memoryStats(PyObject *_unused, PyObject *arg)
 {
     HANDLE_TH_ERRORS
@@ -1860,6 +1868,7 @@ static struct PyMethodDef THNPModule_methods[] = {
     {"_npu_is_jit_compile_false", (PyCFunction)THNPModule_is_jit_compile_false_wrap, METH_NOARGS, nullptr},
     {"_npu_setMemoryFraction", (PyCFunction) THNPModule_setMemoryFraction, METH_VARARGS, nullptr},
     {"_npu_emptyCache", (PyCFunction) THNPModule_emptyCache, METH_NOARGS, nullptr},
+    {"_npu_emptyVirtAddrCache", (PyCFunction) THNPModule_emptyVirtAddrCache, METH_NOARGS, nullptr},
     {"_npu_memoryStats", (PyCFunction) THNPModule_memoryStats, METH_O, nullptr},
     {"_npu_resetAccumulatedMemoryStats", (PyCFunction) THNPModule_resetAccumulatedMemoryStats, METH_O, nullptr},
     {"_npu_resetPeakMemoryStats", (PyCFunction) THNPModule_resetPeakMemoryStats, METH_O, nullptr},
diff --git a/torch_npu/csrc/npu/NPUPluggableAllocator.cpp b/torch_npu/csrc/npu/NPUPluggableAllocator.cpp
index 7610374a3ba35297c97eac5d17dbd11cc3bba0b9..cc4f03f6067081e323c7218e0910472cb20f6ad5 100644
--- a/torch_npu/csrc/npu/NPUPluggableAllocator.cpp
+++ b/torch_npu/csrc/npu/NPUPluggableAllocator.cpp
@@ -182,6 +182,12 @@ void NPUPluggableAllocator::setMemoryFraction(double fraction, int device)
     }
 }
 
+void NPUPluggableAllocator::emptyCacheImpl(bool check_error, bool free_physical)
+{
+    TORCH_NPU_WARN("NPUPluggableAllocator does not yet support emptyCacheImpl. "
+                   "If you need it, please file an issue describing your use case.");
+}
+
 void NPUPluggableAllocator::emptyCache(bool check_error)
 {
     if (reset_fn_) {
@@ -189,6 +195,12 @@
     }
 }
 
+void NPUPluggableAllocator::emptyVirtAddrCache(bool check_error)
+{
+    TORCH_NPU_WARN("NPUPluggableAllocator does not yet support emptyVirtAddrCache. "
+                   "If you need it, please file an issue describing your use case.");
+}
+
 void NPUPluggableAllocator::clearIpcHandles()
 {
     TORCH_NPU_WARN("NPUPluggableAllocator does not yet support clearIpcHandles. "
" diff --git a/torch_npu/csrc/npu/NPUPluggableAllocator.h b/torch_npu/csrc/npu/NPUPluggableAllocator.h index 266db02a604c906f0e5a4abf6e07d0f407504613..41e768979980c3e7fac218b176a5c8c5ea700781 100644 --- a/torch_npu/csrc/npu/NPUPluggableAllocator.h +++ b/torch_npu/csrc/npu/NPUPluggableAllocator.h @@ -59,7 +59,9 @@ struct NPUPluggableAllocator void init(int device_count) override; bool initialized() override; void setMemoryFraction(double fraction, int device) override; + void emptyCacheImpl(bool check_error, bool free_physical) override; void emptyCache(bool check_error) override; + void emptyVirtAddrCache(bool check_error) override; void clearIpcHandles() override; void cacheInfo(int dev_id, size_t* cachedAndFree, size_t* largestBlock) override; void* getBaseAllocation(void* ptr, size_t* size) override; diff --git a/torch_npu/npu/__init__.py b/torch_npu/npu/__init__.py index 75bf03d13a2f6d2ab9b7c5cefe49ee3da86e2a61..33db3e1fa7e6f0c77132ea440af7c8f8a4673284 100644 --- a/torch_npu/npu/__init__.py +++ b/torch_npu/npu/__init__.py @@ -33,6 +33,7 @@ __all__ = [ "caching_allocator_delete", "set_per_process_memory_fraction", "empty_cache", + "empty_virt_addr_cache", "memory_stats", "memory_stats_as_nested_dict", "reset_accumulated_memory_stats", diff --git a/torch_npu/npu/memory.py b/torch_npu/npu/memory.py index 744757878291584084e329741ebf198e1bd2ad05..ad5c816043819e699a0f9f663a0af52e466fb649 100644 --- a/torch_npu/npu/memory.py +++ b/torch_npu/npu/memory.py @@ -20,6 +20,7 @@ __all__ = [ "caching_allocator_delete", "set_per_process_memory_fraction", "empty_cache", + "empty_virt_addr_cache", "memory_stats", "memory_stats_as_nested_dict", "reset_accumulated_memory_stats", @@ -158,6 +159,14 @@ def empty_cache(): torch_npu._C._npu_emptyCache() +def empty_virt_addr_cache(): + r"""Light-weight version of empty_cache(). It only unmaps virtual address, + and store the free physical handles for later malloc. + """ + if is_initialized(): + torch_npu._C._npu_emptyVirtAddrCache() + + def memory_stats(device=None): """Returns a dictionary of NPU memory allocator statistics for a given device.