diff --git a/torch_npu/csrc/core/npu/NPUCachingAllocator.cpp b/torch_npu/csrc/core/npu/NPUCachingAllocator.cpp index 8c2f0e26a25a6ee17e2bd777f89092fd9cf79a6f..b99e97f2f1d8d4b9c88caa5fcff85aa22312020e 100644 --- a/torch_npu/csrc/core/npu/NPUCachingAllocator.cpp +++ b/torch_npu/csrc/core/npu/NPUCachingAllocator.cpp @@ -371,29 +371,32 @@ struct ExpandableSegment { ExpandableSegment( int device, std::optional stream, - size_t size) + size_t size, + size_t handleNum = 0) : device_(device), stream_(stream), - max_handles_(0), + max_handles_(handleNum * 2), // 2MB for small pool, 20MB for large pool segment_size_(size) { - size_t device_free; - size_t device_total; - NPU_CHECK_ERROR(aclrtGetMemInfo(ACL_HBM_MEM, &device_free, &device_total)); - // we allocate enough address space for 1 1/8 the total memory on the NPU. - // This allows for some cases where we have to unmap pages earlier in the - // segment to put them at the end. - max_handles_ = numSegments(device_total + device_total / 8); - if (c10_npu::option::OptionsManager::IsHcclZeroCopyEnable()) { - // prevent HCCL reserve virtual address out of memory - // small pool reserve 2G - // non-default stream large pool 10G - auto default_stream = c10_npu::getDefaultNPUStream().stream(false); - if (kSmallBuffer == segment_size_) { - max_handles_ = numSegments(kSmallPoolVirAddrSize); - } else if (default_stream != *stream) { - max_handles_ = numSegments(kLargePoolVirAddrSize); + if (max_handles_ == 0) { + size_t device_free; + size_t device_total; + NPU_CHECK_ERROR(aclrtGetMemInfo(ACL_HBM_MEM, &device_free, &device_total)); + // we allocate enough address space for 1 1/8 the total memory on the NPU. + // This allows for some cases where we have to unmap pages earlier in the + // segment to put them at the end. + max_handles_ = numSegments(device_total + device_total / 8); + if (c10_npu::option::OptionsManager::IsHcclZeroCopyEnable()) { + // prevent HCCL reserve virtual address out of memory + // small pool reserve 2G + // non-default stream large pool 10G + auto default_stream = c10_npu::getDefaultNPUStream().stream(false); + if (kSmallBuffer == segment_size_) { + max_handles_ = numSegments(kSmallPoolVirAddrSize); + } else if (default_stream != *stream) { + max_handles_ = numSegments(kLargePoolVirAddrSize); + } } } @@ -509,7 +512,8 @@ struct ExpandableSegment { auto segment = std::make_unique( device, std::nullopt, - header.segment_size); + header.segment_size, + header.num_handles); for (auto i : c10::irange(header.num_handles)) { (void)i; uint64_t shareableHandle = 0;