From d6e331df1a1af1bed6b5f5a32cc3ada1037f29f2 Mon Sep 17 00:00:00 2001 From: "jiangrongqiang@huawei.com" Date: Fri, 4 Mar 2022 16:39:34 +0800 Subject: [PATCH] add optimize queue code to master --- torch_npu/csrc/InitNpuBindings.cpp | 2 +- torch_npu/csrc/aten/common/CopyKernel.cpp | 6 +- torch_npu/csrc/aten/common/CopyKernelNpu.cpp | 9 +- .../csrc/aten/common/CopyMemoryKernel.cpp | 9 +- torch_npu/csrc/aten/common/ResizeNpu.h | 8 +- .../csrc/core/npu/NPUCachingAllocator.cpp | 57 ++++- torch_npu/csrc/core/npu/NPUCachingAllocator.h | 2 + .../core/npu/THNPUCachingHostAllocator.cpp | 131 +++++++--- .../csrc/core/npu/THNPUCachingHostAllocator.h | 7 +- torch_npu/csrc/framework/NPUDefine.cpp | 95 +++----- torch_npu/csrc/framework/NPUDefine.h | 5 +- torch_npu/csrc/framework/OpCommandBase.h | 10 +- torch_npu/csrc/framework/OpParamMaker.cpp | 228 +++++++++++++----- torch_npu/csrc/framework/OpParamMaker.h | 86 +++++-- .../framework/contiguous/broadcast_opt.cpp | 7 +- .../csrc/framework/utils/CalcuOpUtil.cpp | 61 +++-- torch_npu/csrc/framework/utils/NpuUtils.cpp | 7 +- 17 files changed, 503 insertions(+), 227 deletions(-) diff --git a/torch_npu/csrc/InitNpuBindings.cpp b/torch_npu/csrc/InitNpuBindings.cpp index 7e015fed5a..721d70e146 100644 --- a/torch_npu/csrc/InitNpuBindings.cpp +++ b/torch_npu/csrc/InitNpuBindings.cpp @@ -16,6 +16,7 @@ #include #include #include +#include "torch_npu/csrc/npu/Event.h" #include "torch_npu/csrc/core/npu/NPUCachingAllocator.h" #include "torch_npu/csrc/framework/graph/execute/GraphExecutor.h" @@ -23,7 +24,6 @@ #include #include "torch_npu/csrc/core/npu/THNPUCachingHostAllocator.h" -#include "torch_npu/csrc/npu/Event.h" #include "torch_npu/csrc/distributed/Init.h" #include "torch_npu/csrc/profiler/init.h" diff --git a/torch_npu/csrc/aten/common/CopyKernel.cpp b/torch_npu/csrc/aten/common/CopyKernel.cpp index 1b6725e841..f495e28bd4 100644 --- a/torch_npu/csrc/aten/common/CopyKernel.cpp +++ b/torch_npu/csrc/aten/common/CopyKernel.cpp @@ -147,8 +147,11 @@ void copy_between_host_and_device( void* src_ptr = src.data_ptr(); int64_t nbytes = dst.numel() * dst.element_size(); c10::npu::NPUStream stream = c10::npu::getCurrentNPUStream(); + at::Tensor tmp = dst.is_npu() ? 
src : dst; + c10::Storage tmpSt = tmp.storage(); + bool is_pinned = THNPUCachingHostAllocator_isPinndPtr(tmp.data_ptr()); AT_NPU_CHECK( - aclrtMemcpyAsync(dst_ptr, nbytes, src_ptr, nbytes, kind, stream)); + c10::npu::queue::LaunchAsyncCopyTask(dst_ptr, nbytes, src_ptr, nbytes, kind, tmpSt, is_pinned)); if (non_blocking) { NPU_LOGD("non_blocking copy without StreamSynchronize."); @@ -157,6 +160,7 @@ void copy_between_host_and_device( } else { aclError error = aclrtSynchronizeStream(stream); if (error != ACL_ERROR_NONE) { + C10_NPU_SHOW_ERR_MSG(); AT_ERROR("ACL stream synchronize failed, error code:", error); } } diff --git a/torch_npu/csrc/aten/common/CopyKernelNpu.cpp b/torch_npu/csrc/aten/common/CopyKernelNpu.cpp index cb669c4b2d..2fe3560a7d 100644 --- a/torch_npu/csrc/aten/common/CopyKernelNpu.cpp +++ b/torch_npu/csrc/aten/common/CopyKernelNpu.cpp @@ -17,6 +17,7 @@ #include "torch_npu/csrc/framework/utils/CalcuOpUtil.h" #include "torch_npu/csrc/framework/StorageDescHelper.h" #include "torch_npu/csrc/aten/common/InnerNpuNativeFunction.h" +#include namespace at_npu { namespace native { @@ -107,16 +108,14 @@ void copy_d2d_by_memcpy(at::Tensor& dst, const at::Tensor& src, int64_t exceptSi return; } - c10::npu::NPUStream copy_stream = c10::npu::getCurrentNPUStream(); - aclError error = aclrtMemcpyAsync( + aclError error = c10::npu::queue::LaunchAsyncCopyTask( dst.data_ptr(), size * dst.element_size(), src.data_ptr(), size * dst.element_size(), - ACL_MEMCPY_DEVICE_TO_DEVICE, - copy_stream); + ACL_MEMCPY_DEVICE_TO_DEVICE); if (error != ACL_ERROR_NONE) { - AT_ERROR("aclrtMemcpy device to device error."); + AT_ERROR("async copy device to device error."); return; } } diff --git a/torch_npu/csrc/aten/common/CopyMemoryKernel.cpp b/torch_npu/csrc/aten/common/CopyMemoryKernel.cpp index d39c5032e0..5e7dfe3a66 100644 --- a/torch_npu/csrc/aten/common/CopyMemoryKernel.cpp +++ b/torch_npu/csrc/aten/common/CopyMemoryKernel.cpp @@ -20,6 +20,7 @@ #include "torch_npu/csrc/framework/utils/CalcuOpUtil.h" #include "torch_npu/csrc/framework/FormatHelper.h" #include "torch_npu/csrc/aten/NPUNativeFunctions.h" +#include #include "third_party/acl/inc/acl/acl.h" @@ -62,18 +63,16 @@ at::Tensor& NPUNativeFunctions::copy_memory_(at::Tensor& self, const at::Tensor& src_size = (src_element > src_storage) ? src_storage : src_element; } - c10::npu::NPUStream stream = c10::npu::getCurrentNPUStream(); - // Designed for the gather of tensors, ignoring npu_format_ and // copying continuous memory between npu tensors. 
- AT_NPU_CHECK(aclrtMemcpyAsync( + AT_NPU_CHECK(c10::npu::queue::LaunchAsyncCopyTask( self.data_ptr(), dst_size * self.itemsize(), src.data_ptr(), dst_size * self.itemsize(), - ACL_MEMCPY_DEVICE_TO_DEVICE, - stream)); + ACL_MEMCPY_DEVICE_TO_DEVICE)); if (!non_blocking) { + c10::npu::NPUStream stream = c10::npu::getCurrentNPUStream(); AT_NPU_CHECK(aclrtSynchronizeStream(stream)); } return self; diff --git a/torch_npu/csrc/aten/common/ResizeNpu.h b/torch_npu/csrc/aten/common/ResizeNpu.h index 2f24a77e7b..79cdd3405b 100644 --- a/torch_npu/csrc/aten/common/ResizeNpu.h +++ b/torch_npu/csrc/aten/common/ResizeNpu.h @@ -19,7 +19,7 @@ #include #include #include - +#include #include "torch_npu/csrc/framework/utils/NpuUtils.h" #include "torch_npu/csrc/framework/StorageDescHelper.h" @@ -50,14 +50,12 @@ static void storage_resize_npu( copy_size = storage.nbytes(); } if (copy_size > 0) { - c10::npu::NPUStream copy_stream = c10::npu::getCurrentNPUStream(); - aclError error = aclrtMemcpyAsync( + aclError error = c10::npu::queue::LaunchAsyncCopyTask( storage.data(), copy_size, old_data.get(), copy_size, - ACL_MEMCPY_DEVICE_TO_DEVICE, - copy_stream); + ACL_MEMCPY_DEVICE_TO_DEVICE); if (error != ACL_ERROR_NONE) { AT_ERROR("ACL_Memcpy device to device error."); return; diff --git a/torch_npu/csrc/core/npu/NPUCachingAllocator.cpp b/torch_npu/csrc/core/npu/NPUCachingAllocator.cpp index cca744252f..ee338d58db 100644 --- a/torch_npu/csrc/core/npu/NPUCachingAllocator.cpp +++ b/torch_npu/csrc/core/npu/NPUCachingAllocator.cpp @@ -14,6 +14,7 @@ // See the License for the specific language governing permissions and // limitations under the License. + #include #include #include @@ -30,6 +31,7 @@ #include #include #include +#include #include "third_party/acl/inc/acl/acl_base.h" #include "third_party/acl/inc/acl/acl_rt.h" @@ -236,6 +238,8 @@ struct THNCachingAllocator { // lock around calls to aclFree (to prevent deadlocks with NCCL) mutable std::mutex npu_free_mutex; + mutable std::mutex recorded_event_mutex; + // cached blocks larger than 1 MB BlockPool large_blocks; @@ -248,6 +252,8 @@ struct THNCachingAllocator { // outstanding acl events std::deque<std::pair<aclrtEvent, Block*>> npu_events; + std::set<aclrtEvent> recorded_events; + THNCachingAllocator() : large_blocks(BlockComparator), small_blocks(BlockComparator) {} @@ -769,6 +775,7 @@ struct THNCachingAllocator { err = aclrtMalloc( devPtr, size, aclrtMemMallocPolicy::ACL_MEM_MALLOC_HUGE_FIRST); if (err != ACL_ERROR_NONE) { + C10_NPU_SHOW_ERR_MSG(); return err; } } @@ -833,6 +840,14 @@ struct THNCachingAllocator { for (auto& e : npu_events) { aclrtEvent event = e.first; + { + std::lock_guard<std::mutex> lock(recorded_event_mutex); + auto it = recorded_events.find(event); + if (c10::npu::OptionsManager::CheckQueueEnable() && + it == recorded_events.end()) { + break; + } + } Block* block = e.second; if (device.has_value() && block->device != *device) { remaining_events.push_back(e); } C10_NPU_CHECK(aclrtSynchronizeEvent(event)); + { + std::lock_guard<std::mutex> lock(recorded_event_mutex); + auto it = recorded_events.find(event); + if (it != recorded_events.end()) { + recorded_events.erase(it); + } + } C10_NPU_CHECK(aclrtDestroyEvent(event)); - block->event_count--; if (block->event_count == 0) { free_block(block); @@ -859,6 +880,11 @@ struct THNCachingAllocator { return it->second; } + void insertRecordedEvent(aclrtEvent event) { + std::lock_guard<std::mutex> lock(recorded_event_mutex); + recorded_events.insert(event); + } + void insert_events(Block* block) { int prev_device = 0;
C10_NPU_CHECK(aclrtGetDevice(&prev_device)); @@ -874,9 +900,9 @@ struct THNCachingAllocator { C10_NPU_CHECK(aclrtSetDevice(it->device_index())); } - aclrtEvent event; - aclrtCreateEvent(&event); - aclrtRecordEvent(event, it->stream()); + aclrtEvent event = nullptr; + C10_NPU_CHECK(c10::npu::acl::AclrtCreateEventWithFlag(&event, ACL_EVENT_TIME_LINE)); + c10::npu::queue::NpuAllocatorLaunchRecordEventTask(event, *it); block->event_count++; npu_events.emplace_back(event, block); @@ -902,6 +928,16 @@ struct THNCachingAllocator { aclrtEvent event = e.first; Block* block = e.second; + { + std::lock_guard<std::mutex> lock(recorded_event_mutex); + auto it = recorded_events.begin(); + it = recorded_events.find(event); + if (c10::npu::OptionsManager::CheckQueueEnable() && + it == recorded_events.end()) { + break; + } + } + aclrtEventStatus status = ACL_EVENT_STATUS_RESERVED; aclError err = aclrtQueryEvent(event, &status); if (err != ACL_ERROR_NONE) { @@ -911,7 +947,14 @@ break; } - aclrtDestroyEvent(event); + { + std::lock_guard<std::mutex> lock(recorded_event_mutex); + auto it = recorded_events.find(event); + if (it != recorded_events.end()) { + recorded_events.erase(it); + } + } + C10_NPU_CHECK(aclrtDestroyEvent(event)); block->event_count--; if (block->event_count == 0) { @@ -1092,6 +1135,10 @@ std::vector snapshot() { return caching_allocator.snapshot(); } +void NpuAllocatorInsertRecordedEvent(aclrtEvent event) { + return caching_allocator.insertRecordedEvent(event); +} + uint64_t currentMemoryAllocated(int device) { assertValidDevice(device); return caching_allocator.get_stats_for_device(device).amount_allocated; diff --git a/torch_npu/csrc/core/npu/NPUCachingAllocator.h b/torch_npu/csrc/core/npu/NPUCachingAllocator.h index 8b7eed8e82..b5d61c5bd8 100644 --- a/torch_npu/csrc/core/npu/NPUCachingAllocator.h +++ b/torch_npu/csrc/core/npu/NPUCachingAllocator.h @@ -21,6 +21,7 @@ #include #include #include +#include #include @@ -140,5 +141,6 @@ std::mutex* getFreeMutex(); void FreeDeviceCachedMemory(int device); +C10_NPU_API void NpuAllocatorInsertRecordedEvent(aclrtEvent event); } // namespace NPUCachingAllocator } // namespace c10 diff --git a/torch_npu/csrc/core/npu/THNPUCachingHostAllocator.cpp b/torch_npu/csrc/core/npu/THNPUCachingHostAllocator.cpp index 80cbd38604..bc545a61d0 100644 --- a/torch_npu/csrc/core/npu/THNPUCachingHostAllocator.cpp +++ b/torch_npu/csrc/core/npu/THNPUCachingHostAllocator.cpp @@ -18,6 +18,9 @@ #include #include #include +#include "c10/npu/interface/AsyncTaskQueueInterface.h" +#include "c10/npu/interface/AclInterface.h" +#include "c10/npu/OptionsManager.h" #include @@ -72,6 +75,11 @@ struct HostAllocator { // outstanding ACL events std::deque<std::pair<aclrtEvent, void*>> npu_events; + // record events + std::mutex record_mutex; + std::set<aclrtEvent> complete_events; + + HostAllocator() : available(BlockComparator) {} aclError malloc(void** ptr, size_t size) { @@ -108,38 +116,43 @@ } aclError free(void* ptr) { - std::lock_guard<std::mutex> lock(mutex); + c10::SmallVector needClearVec; + { + std::lock_guard<std::mutex> lock(mutex); - if (!ptr) { - return ACL_ERROR_NONE; - } + if (!ptr) { + return ACL_ERROR_NONE; + } - // process outstanding npu events which may have occurred - aclError err = processEvents(); - if (err != ACL_ERROR_NONE) { - return err; - } + // process outstanding npu events which may have occurred + aclError err = processEvents(); + if (err != ACL_ERROR_NONE) { + return err; + } - auto it = blocks.find(ptr); - AT_ASSERT(it !=
blocks.end()); - Block& block = it->second; - AT_ASSERT(block.allocated); + Block& block = it->second; + AT_ASSERT(block.allocated); - // free (on valid memory) shouldn't fail, so mark unallocated before - // we process the streams. - block.allocated = false; + // free (on valid memory) shouldn't fail, so mark unallocated before + // we process the streams. + block.allocated = false; - // insert NPU events for each stream on which this block was used. This - err = insertEvents(block); - if (err != ACL_ERROR_NONE) { - return err; - } + // insert npu events for each stream on which this block was used. This + err = insertEvents(block, needClearVec); + if (err != ACL_ERROR_NONE) { + return err; + } - if (block.event_count == 0) { - // the block can be re-used if there are no outstanding npu events - available.insert(block); + if (block.event_count == 0) { + // the block can be re-used if there are no outstanding npu events + available.insert(block); + } } + // free pin memory + needClearVec.clear(); return ACL_ERROR_NONE; } @@ -151,6 +164,7 @@ // Sync when host memory is allocated by malloc aclError error = aclrtSynchronizeStream(stream); if (error != ACL_ERROR_NONE) { + C10_NPU_SHOW_ERR_MSG(); AT_ERROR("ACL stream synchronize failed."); return error; } @@ -164,6 +178,33 @@ return ACL_ERROR_NONE; } + bool isPinndPtr(void* ptr) + { + std::lock_guard<std::mutex> lock(mutex); + return blocks.find(ptr) != blocks.end(); + } + + void insertCompleteEvent(aclrtEvent event) + { + if (c10::npu::OptionsManager::CheckQueueEnable()) { + std::lock_guard<std::mutex> lock(record_mutex); + complete_events.insert(event); + } + } + + bool findAndEraseCompleteEvent(aclrtEvent event) + { + if (c10::npu::OptionsManager::CheckQueueEnable()) { + std::lock_guard<std::mutex> lock(record_mutex); + auto it = complete_events.find(event); + if (it == complete_events.end()) { + return false; + } + complete_events.erase(it); + } + return true; + } + aclError processEvents() { // Process outstanding npuEvents.
Events that are completed are removed // from the queue, and the 'event_count' for the corresponding allocation @@ -173,15 +214,27 @@ struct HostAllocator { while (!npu_events.empty()) { auto& e = npu_events.front(); aclrtEvent event = e.first; - aclrtEventStatus status = ACL_EVENT_STATUS_COMPLETE; - aclError err = aclrtQueryEvent(event, &status); - if (status == ACL_EVENT_STATUS_NOT_READY) { + // when TASK_QUEUE_ENABLE is set, pytorch thread can destroy event + // after acl thread has launched record event task + if (!findAndEraseCompleteEvent(event)) { break; - } else if (err != ACL_ERROR_NONE) { + } + aclrtEventStatus status = ACL_EVENT_STATUS_RESERVED; + aclError err = aclrtQueryEvent(event, &status); + if (err != ACL_ERROR_NONE) { + C10_NPU_SHOW_ERR_MSG(); + insertCompleteEvent(event); return err; } + if (status != ACL_EVENT_STATUS_COMPLETE) { + insertCompleteEvent(event); + break; + } + err = aclrtDestroyEvent(event); if (err != ACL_ERROR_NONE) { + C10_NPU_SHOW_ERR_MSG(); + insertCompleteEvent(event); return err; } @@ -204,6 +257,7 @@ struct HostAllocator { Block& block = blocks.at(it->second); if (!block.allocated) { if (aclrtDestroyEvent(event) != ACL_ERROR_NONE) { + C10_NPU_SHOW_ERR_MSG(); NPU_LOGW("destory acl event fail"); } block.event_count--; @@ -231,7 +285,7 @@ struct HostAllocator { } } - aclError insertEvents(Block& block) { + aclError insertEvents(Block& block, c10::SmallVector& needClearVec) { aclError err = ACL_ERROR_NONE; int prev_device = 0; @@ -246,21 +300,24 @@ struct HostAllocator { if (ret != ACL_ERROR_NONE) { err = aclrtSetDevice(it->device_index()); if (err != ACL_ERROR_NONE) { + C10_NPU_SHOW_ERR_MSG(); break; } } else if (pre_device != it->device_index()) { err = aclrtSetDevice(it->device_index()); if (err != ACL_ERROR_NONE) { + C10_NPU_SHOW_ERR_MSG(); break; } } aclrtEvent event = nullptr; - err = aclrtCreateEvent(&event); - if (err != ACL_ERROR_NONE) + err = c10::npu::acl::AclrtCreateEventWithFlag(&event, ACL_EVENT_TIME_LINE); + if (err != ACL_ERROR_NONE) { + C10_NPU_SHOW_ERR_MSG(); break; - - err = aclrtRecordEvent(event, it->stream()); + } + err = c10::npu::queue::HostAllocatorLaunchRecordEventTask(event, *it, needClearVec); if (err != ACL_ERROR_NONE) break; @@ -288,6 +345,14 @@ aclError THNPUCachingHostAllocator_recordEvent( return allocator.recordEvent(ptr, stream); } +void THNPUCachingHostAllocator_insertCompleteEvent(aclrtEvent event) { + return allocator.insertCompleteEvent(event); +} + +bool THNPUCachingHostAllocator_isPinndPtr(void* ptr) { + return allocator.isPinndPtr(ptr); +} + void THNPUCachingHostAllocator_emptyCache() { allocator.emptyCache(); } diff --git a/torch_npu/csrc/core/npu/THNPUCachingHostAllocator.h b/torch_npu/csrc/core/npu/THNPUCachingHostAllocator.h index 2874da1d7d..f7a6826260 100644 --- a/torch_npu/csrc/core/npu/THNPUCachingHostAllocator.h +++ b/torch_npu/csrc/core/npu/THNPUCachingHostAllocator.h @@ -17,12 +17,15 @@ #include #include #include +#include C10_NPU_API c10::Allocator* getTHNPUCachingHostAllocator(void); -C10_NPU_API aclError -THNPUCachingHostAllocator_recordEvent(void* ptr, at::npu::NPUStream stream); +C10_NPU_API aclError THNPUCachingHostAllocator_recordEvent(void* ptr, at::npu::NPUStream stream); +C10_NPU_API void THNPUCachingHostAllocator_insertCompleteEvent(aclrtEvent event); + +C10_NPU_API bool THNPUCachingHostAllocator_isPinndPtr(void* ptr); // Releases cached pinned memory allocations via npuHostFree C10_NPU_API void THNPUCachingHostAllocator_emptyCache(void); diff --git a/torch_npu/csrc/framework/NPUDefine.cpp 
b/torch_npu/csrc/framework/NPUDefine.cpp index 62b3501b0e..daf77bc9ba 100644 --- a/torch_npu/csrc/framework/NPUDefine.cpp +++ b/torch_npu/csrc/framework/NPUDefine.cpp @@ -41,46 +41,36 @@ namespace at_npu void ExecuteParas::Copy(ExecuteParas &other) { - auto srcPtr = &other; - this->opType = srcPtr->opType; - this->attrInfo = srcPtr->attrInfo; - this->paras = srcPtr->paras; - this->attr = srcPtr->attr; - this->constParams = srcPtr->constParams; - this->hostMemory = srcPtr->hostMemory; + this->opType = other.opType; + this->attrInfo = other.attrInfo; + this->paras = other.paras; + this->attr = other.attr; + this->constParams = other.constParams; + this->hostMemory = other.hostMemory; + this->isFuzzy = other.isFuzzy; + this->isCompiling = other.isCompiling; } - NPUStatus DestroyAclParams(ACL_PARAMS ¶ms) + void ExecuteParas::CopyEx(ExecuteParas& other) { - if (params.input_num != 0) - { - if (params.input_desc != nullptr) - { - for (int i = 0; i < params.input_num; ++i) - { + this->paras = other.paras; + this->attr = other.attr; + this->constParams = other.constParams; + this->isCompiling = other.isCompiling; + } + + NPUStatus DestroyAclParams(ACL_PARAMS& params) + { + if (params.input_num != 0) { + if (params.input_desc != nullptr) { + for (int i = 0; i < params.input_num; ++i) { aclDestroyTensorDesc(params.input_desc[i]); } - delete[] params.input_desc; - params.input_desc = nullptr; } - if (params.inputDims != nullptr) - { - delete[] params.inputDims; - params.inputDims = nullptr; - } - if (params.inputFormats != nullptr) - { - delete[] params.inputFormats; - params.inputFormats = nullptr; - } - if (params.input_data_buf != nullptr) - { - for (int i = 0; i < params.input_num; ++i) - { + if (params.input_data_buf != nullptr) { + for (int i = 0; i < params.input_num; ++i) { C10_NPU_CHECK(aclDestroyDataBuffer(params.input_data_buf[i])); } - delete[] params.input_data_buf; - params.input_data_buf = nullptr; } params.input_num = 0; } @@ -92,31 +82,24 @@ namespace at_npu { aclDestroyTensorDesc(params.output_desc[i]); } - delete[] params.output_desc; - params.output_desc = nullptr; } - if (params.outputDims != nullptr) - { - delete[] params.outputDims; - params.outputDims = nullptr; - } - if (params.outputFormats != nullptr) - { - delete[] params.outputFormats; - params.outputFormats = nullptr; - } - if (params.output_data_buf != nullptr) { - for (int i = 0; i < params.output_num; ++i) - { + for (int i = 0; i < params.output_num; ++i) { C10_NPU_CHECK(aclDestroyDataBuffer(params.output_data_buf[i])); } - delete[] params.output_data_buf; - params.output_data_buf = nullptr; } params.output_num = 0; } + free(params.input_desc); + params.input_desc = nullptr; + params.inputDims = nullptr; + params.inputFormats = nullptr; + params.input_data_buf = nullptr; + params.output_desc = nullptr; + params.outputDims = nullptr; + params.outputFormats = nullptr; + params.output_data_buf = nullptr; return SUCCESS; } @@ -126,21 +109,13 @@ namespace at_npu { for (int i = 0; i < params.constNum; ++i) { - if (params.constList[i] != nullptr) - { + if (params.constList[i] != nullptr) { delete[] params.constList[i]; } } - delete[] params.constList; - params.constList = nullptr; - } - - if (params.constIdx != nullptr) - { - delete[] params.constIdx; - params.constIdx = nullptr; } + params.constList = nullptr; + params.constIdx = nullptr; } - } // namespace native } // namespace at_npu \ No newline at end of file diff --git a/torch_npu/csrc/framework/NPUDefine.h b/torch_npu/csrc/framework/NPUDefine.h index 
510fea2b3d..a16714e4c4 100644 --- a/torch_npu/csrc/framework/NPUDefine.h +++ b/torch_npu/csrc/framework/NPUDefine.h @@ -100,11 +100,12 @@ namespace at_npu std::string opType; std::string attrInfo; bool isCompiling = false; + bool isFuzzy = false; ACL_PARAMS paras; CONST_PARAMS constParams; const aclopAttr *attr; int64_t constIdx = -1; - c10::SmallVector hostMemory; + c10::SmallVector hostMemory; ExecuteParas( std::string opName, aclopAttr *acl_attr, @@ -117,11 +118,11 @@ namespace at_npu ExecuteParas() = default; void Release(); void Copy(ExecuteParas &other); + void CopyEx(ExecuteParas& other); }; NPUStatus DestroyAclParams(ACL_PARAMS ¶ms); void DestroyConstParams(CONST_PARAMS ¶ms); - } // namespace native } // namespace at_npu diff --git a/torch_npu/csrc/framework/OpCommandBase.h b/torch_npu/csrc/framework/OpCommandBase.h index 49f5a8374e..5143b91213 100644 --- a/torch_npu/csrc/framework/OpCommandBase.h +++ b/torch_npu/csrc/framework/OpCommandBase.h @@ -27,6 +27,7 @@ #include "torch_npu/csrc/core/npu/NPURunMode.h" #include "torch_npu/csrc/aten/NPUNativeFunctions.h" #include "torch_npu/csrc/framework/graph/construct/GraphConstructor.h" +#include "c10/npu/interface/AsyncTaskQueueInterface.h" #define IF_GRAPH_MODE_THEN_RUN(...) \ do { \ @@ -222,10 +223,13 @@ public: void Run() { IF_GRAPH_MODE_THEN_RUN(return;) if (torch_npu::option::OptionsManager::CheckQueueEnable()) { - ExecuteParas params; - aclCmd->ExportParams(params); - c10::npu::enCurrentNPUStream(¶ms); + ExecuteParas execParams; + aclCmd->ExportParams(execParams); + c10::npu::queue::QueueParas params(c10::npu::queue::COMPILE_AND_EXECUTE, sizeof(ExecuteParas), &execParams); + c10::SmallVector needClearVec; + c10::npu::enCurrentNPUStream(¶ms, needClearVec); aclCmd->releaseSource(false); + needClearVec.clear(); } else { aclCmd->Run(); aclCmd->releaseSource(); diff --git a/torch_npu/csrc/framework/OpParamMaker.cpp b/torch_npu/csrc/framework/OpParamMaker.cpp index ec43818e3c..48acf5e9ad 100644 --- a/torch_npu/csrc/framework/OpParamMaker.cpp +++ b/torch_npu/csrc/framework/OpParamMaker.cpp @@ -19,50 +19,52 @@ #include "torch_npu/csrc/core/npu/NPUCachingAllocator.h" #include "torch_npu/csrc/core/npu/register/OptionsManager.h" #include "torch_npu/csrc/framework/aoe/AoeUtils.h" -#include "torch_npu/csrc/framework/utils/NpuFuzzyBlacklist.h" #include "torch_npu/csrc/framework/utils/CalcuOpUtil.h" #include "torch_npu/csrc/framework/utils/NpuUtils.h" -#include "torch_npu/csrc/framework/interface/EnvVariables.h" #include "torch_npu/csrc/framework/OpParamMaker.h" +#include "torch_npu/csrc/core/npu/THNPUCachingHostAllocator.h" +#include "torch_npu/csrc/core/npu/NPUCachingAllocator.h" +#include "c10/npu/NPUEventManager.h" +#include "c10/npu/interface/AsyncTaskQueueInterface.h" namespace at_npu { namespace native { - void OpAttrMaker::Set(aclopAttr *attr, string name, bool value) + void OpAttrMaker::Set(aclopAttr *attr, const string &name, bool value) { aclopSetAttrBool(attr, name.c_str(), value); } - void OpAttrMaker::Set(aclopAttr *attr, string name, int64_t value) + void OpAttrMaker::Set(aclopAttr *attr, const string &name, int64_t value) { aclopSetAttrInt(attr, name.c_str(), value); } - void OpAttrMaker::Set(aclopAttr *attr, string name, float value) + void OpAttrMaker::Set(aclopAttr *attr, const string &name, float value) { aclopSetAttrFloat(attr, name.c_str(), value); } - void OpAttrMaker::Set(aclopAttr *attr, string name, string value) + void OpAttrMaker::Set(aclopAttr *attr, const string &name, string value) { aclopSetAttrString(attr, 
name.c_str(), value.c_str()); } - void OpAttrMaker::Set(aclopAttr *attr, string name, c10::IntArrayRef value) + void OpAttrMaker::Set(aclopAttr *attr, const string &name, c10::IntArrayRef value) { auto vec = value.vec(); aclopSetAttrListInt(attr, name.c_str(), vec.size(), vec.data()); } - void OpAttrMaker::Set(aclopAttr *attr, string name, at::ArrayRef<float> value) + void OpAttrMaker::Set(aclopAttr *attr, const string &name, at::ArrayRef<float> value) { auto vec = value.vec(); aclopSetAttrListFloat(attr, name.c_str(), vec.size(), vec.data()); } - void OpAttrMaker::Set(aclopAttr *attr, string name, c10::Scalar value) + void OpAttrMaker::Set(aclopAttr *attr, const string &name, c10::Scalar value) { float val = CalcuOpUtil::get_scalar_float_value(value); aclopSetAttrFloat(attr, name.c_str(), val); @@ -70,7 +72,7 @@ void OpAttrMaker::Set( aclopAttr *attr, - string name, + const string &name, at::ArrayRef<c10::IntArrayRef> value) { // Pointer to values of each listInt. @@ -224,40 +226,20 @@ return ret; } - int ExecFunc(void *in, aclrtStream stream) + int ExecFunc(c10::npu::queue::QueueParas* in, aclrtStream stream) { - auto cur_paras = (ExecuteParas *)in; + auto cur_paras = static_cast<ExecuteParas *>(in->paramVal); NPU_LOGD("Op %s Run.", cur_paras->opType.c_str()); aclError ret; bool reset_flag = false; - if (FuzzyCompileBlacklist::GetInstance().IsInBlacklist(cur_paras->opType) && env::CheckFuzzyEnable()) + if (!cur_paras->isFuzzy) { AclopSetCompileFlag(aclOpCompileFlag::ACL_OP_COMPILE_DEFAULT); reset_flag = true; } - int index = 0; - do - { - if (at_npu::native::aoe::aoe_manager().IsAoeEnabled()) { - ret = at_npu::native::AclGenGraphAndDumpForOp( - (cur_paras->opType).c_str(), - cur_paras->paras.input_num, - cur_paras->paras.input_desc, - cur_paras->paras.input_data_buf, - cur_paras->paras.output_num, - cur_paras->paras.output_desc, - cur_paras->paras.output_data_buf, - cur_paras->attr, - ACL_ENGINE_SYS, - at_npu::native::aoe::aoe_manager().GetDumpGraphPath().c_str(), - nullptr); - if (ret != ACL_ERROR_NONE) { - C10_NPU_SHOW_ERR_MSG(); - TORCH_CHECK(false, "In aoe mode, AclGenGraphAndDumpForOp failed!"); - } - } - ret = aclopCompileAndExecute( + if (at_npu::native::aoe::aoe_manager().IsAoeEnabled()) { + ret = at_npu::native::AclGenGraphAndDumpForOp( (cur_paras->opType).c_str(), cur_paras->paras.input_num, cur_paras->paras.input_desc, cur_paras->paras.input_data_buf, cur_paras->paras.output_num, cur_paras->paras.output_desc, cur_paras->paras.output_data_buf, cur_paras->attr, ACL_ENGINE_SYS, - ACL_COMPILE_SYS, - nullptr, - stream); - ++index; - } while (NpuUtils::IsOomError(ret, index) && (index < NPU_MAX_OP_EXEC_TRY_NUM)); - + at_npu::native::aoe::aoe_manager().GetDumpGraphPath().c_str(), + nullptr); + if (ret != ACL_ERROR_NONE) { + C10_NPU_SHOW_ERR_MSG(); + TORCH_CHECK(false, "In aoe mode, AclGenGraphAndDumpForOp failed!"); + } + } + ret = aclopCompileAndExecute( + (cur_paras->opType).c_str(), + cur_paras->paras.input_num, + cur_paras->paras.input_desc, + cur_paras->paras.input_data_buf, + cur_paras->paras.output_num, + cur_paras->paras.output_desc, + cur_paras->paras.output_data_buf, + cur_paras->attr, + ACL_ENGINE_SYS, + ACL_COMPILE_SYS, + nullptr, + stream); if (reset_flag) { AclopSetCompileFlag(aclOpCompileFlag::ACL_OP_COMPILE_FUZZ); @@ -290,31 +286,155 @@ return ret; } - void CopyFunc(void *dst, void *src) + int MemcopyAsyncFunc(c10::npu::queue::QueueParas* in, aclrtStream stream) + { + auto cur_paras = static_cast<c10::npu::queue::CopyParas *>(in->paramVal); + aclError ret = aclrtMemcpyAsync(cur_paras->dst, cur_paras->dstLen, cur_paras->src, + cur_paras->srcLen,
cur_paras->kind, stream); + if (ret != ACL_ERROR_NONE) { + C10_NPU_SHOW_ERR_MSG(); + } + return ret; + } + + int RecordEventFunc(c10::npu::queue::QueueParas* in, aclrtStream stream) + { + auto cur_paras = static_cast<c10::npu::queue::EventParas *>(in->paramVal); + aclError ret = aclrtRecordEvent(cur_paras->event, stream); + if (ret != ACL_ERROR_NONE) { + C10_NPU_SHOW_ERR_MSG(); + } + + // Temporary modification to avoid problem that + // event must be recorded before query + if (cur_paras->eventAllocatorType == c10::npu::queue::HOST_ALLOCATOR_EVENT) { + THNPUCachingHostAllocator_insertCompleteEvent(cur_paras->event); + } else if (cur_paras->eventAllocatorType == c10::npu::queue::NPU_ALLOCATOR_EVENT) { + c10_npu::NPUCachingAllocator::NpuAllocatorInsertRecordedEvent(cur_paras->event); + } + + return ret; + } + + int WaitEventFunc(c10::npu::queue::QueueParas* in, aclrtStream stream) { + auto cur_paras = static_cast<c10::npu::queue::EventParas *>(in->paramVal); + aclError ret = aclrtStreamWaitEvent(stream, cur_paras->event); + if (ret != ACL_ERROR_NONE) { + C10_NPU_SHOW_ERR_MSG(); + } + return ret; + } + + int LazyDestroyEventFunc(c10::npu::queue::QueueParas* in, aclrtStream stream) { + auto cur_paras = static_cast<c10::npu::queue::EventParas *>(in->paramVal); + aclError ret = c10::npu::NPUEventManager::GetInstance().LazyDestroy(cur_paras->event); + if (ret != ACL_ERROR_NONE) { + C10_NPU_SHOW_ERR_MSG(); + } + return ret; + } + + size_t GetMaxLen(size_t x, size_t y, size_t z) { return x > y ? (x > z ? x : z) : (y > z ? y : z); } - void ReleaseFunc(void *ptr) + void CopyFunc(void* dst, void* src, c10::SmallVector& needClearVec, uint32_t queueLen) { - auto cur_paras = (ExecuteParas *)ptr; - cur_paras->Release(); + auto dstPtr = static_cast<c10::npu::queue::QueueParas *>(dst); + auto srcPtr = static_cast<c10::npu::queue::QueueParas *>(src); + dstPtr->paramVal = static_cast<uint8_t *>(dst) + sizeof(c10::npu::queue::QueueParas); + // pin memory free will add aclrtRecordEvent to queue + // in order to avoid deadlock, pin memory free operation is moved out of the enqueue operation + if (dstPtr->paramType == c10::npu::queue::COMPILE_AND_EXECUTE) { + needClearVec.swap((static_cast<ExecuteParas *>(dstPtr->paramVal))->hostMemory); + // string or smallvector of struct is used, deconstructor need be called before memset + (static_cast<ExecuteParas *>(dstPtr->paramVal))->~ExecuteParas(); + } else if (dstPtr->paramType == c10::npu::queue::ASYNC_MEMCPY_EX) { + needClearVec.swap((static_cast<c10::npu::queue::CopyParas *>(dstPtr->paramVal))->pinMem); + // string or smallvector of struct is used, deconstructor need be called before memset + (static_cast<c10::npu::queue::CopyParas *>(dstPtr->paramVal))->~CopyParas(); + } + dstPtr->paramStream = srcPtr->paramStream; + dstPtr->paramType = srcPtr->paramType; + dstPtr->paramLen = srcPtr->paramLen; + size_t maxSize = GetMaxLen(sizeof(ExecuteParas), sizeof(c10::npu::queue::CopyParas), + sizeof(c10::npu::queue::EventParas)); + memset(dstPtr->paramVal, 0, maxSize); + if (srcPtr->paramType == c10::npu::queue::COMPILE_AND_EXECUTE) { + (static_cast<ExecuteParas *>(dstPtr->paramVal))->Copy(*(static_cast<ExecuteParas *>(srcPtr->paramVal))); + } else if ((srcPtr->paramType == c10::npu::queue::ASYNC_MEMCPY) || + (srcPtr->paramType == c10::npu::queue::ASYNC_MEMCPY_EX)) { + (static_cast<c10::npu::queue::CopyParas *>(dstPtr->paramVal))-> + Copy(*(static_cast<c10::npu::queue::CopyParas *>(srcPtr->paramVal))); + } else { + (static_cast<c10::npu::queue::EventParas *>(dstPtr->paramVal))-> + Copy(*(static_cast<c10::npu::queue::EventParas *>(srcPtr->paramVal))); + } + } + + void ReleaseFunc(void* ptr, c10::npu::ReleaseQueue& releaseQueue) + { + releaseQueue.PushToReleaseQueue(ptr); } - void *NewFunc(int caption, int &size) { - size =
sizeof(ExecuteParas); - return (void *)new ExecuteParas[caption]; + void* NewFunc(int caption, int& size) { + size_t maxSize = GetMaxLen(sizeof(ExecuteParas), sizeof(c10::npu::queue::CopyParas), + sizeof(c10::npu::queue::EventParas)); + size = sizeof(c10::npu::queue::QueueParas) + maxSize; + void *ptr = malloc(size * caption); + TORCH_CHECK(ptr != nullptr, "OpCommand new buffer must be not NULL"); + memset(ptr, 0, size * caption); + return ptr; } - void DeleteFunc(void *ptr) + void DeleteFunc(void* ptr) { - delete[](ExecuteParas *) ptr; + free(ptr); + } + + typedef int (*Func)(c10::npu::queue::QueueParas*, aclrtStream); + using AsyncFuncMap = std::map; + AsyncFuncMap funcMap = { + {c10::npu::queue::COMPILE_AND_EXECUTE, ExecFunc}, + {c10::npu::queue::ASYNC_MEMCPY, MemcopyAsyncFunc}, + {c10::npu::queue::ASYNC_MEMCPY_EX, MemcopyAsyncFunc}, + {c10::npu::queue::RECORD_EVENT, RecordEventFunc}, + {c10::npu::queue::WAIT_EVENT, WaitEventFunc}, + {c10::npu::queue::LAZY_DESTROY_EVENT, LazyDestroyEventFunc}, + }; + + int AsncExecFunc(void* data, uint32_t queueLen) { + auto queueParam = static_cast<c10::npu::queue::QueueParas *>(data); + auto type = queueParam->paramType; + aclrtStream stream = queueParam->paramStream; + auto ret = funcMap[type](queueParam, stream); + return ret; + } + + void CopyReleaseParamFunc(void* dst, void* src) + { + auto dstPtr = static_cast<c10::npu::queue::QueueParas *>(dst); + auto srcPtr = static_cast<c10::npu::queue::QueueParas *>(src); + dstPtr->paramType = srcPtr->paramType; + dstPtr->paramVal = static_cast<uint8_t *>(dst) + sizeof(c10::npu::queue::QueueParas); + if (srcPtr->paramType == c10::npu::queue::COMPILE_AND_EXECUTE) { + (static_cast<ExecuteParas *>(dstPtr->paramVal))->CopyEx(*(static_cast<ExecuteParas *>(srcPtr->paramVal))); + } + } + + void ReleaseParamFunc(void* ptr) { + auto queueParam = static_cast<c10::npu::queue::QueueParas *>(ptr); + auto type = queueParam->paramType; + if (type == c10::npu::queue::COMPILE_AND_EXECUTE) { + auto cur_paras = static_cast<ExecuteParas *>(queueParam->paramVal); + cur_paras->Release(); + } } - REGISTER_QUEUE_FUNC(ExecFunc, CopyFunc, ReleaseFunc, NewFunc, DeleteFunc) + REGISTER_QUEUE_FUNC(AsncExecFunc, CopyFunc, ReleaseFunc, NewFunc, DeleteFunc, + CopyReleaseParamFunc, ReleaseParamFunc) OpCommandImpls *OpCommandImpls::GetInstance() { diff --git a/torch_npu/csrc/framework/OpParamMaker.h b/torch_npu/csrc/framework/OpParamMaker.h index 4711a9d5f6..15407a8972 100644 --- a/torch_npu/csrc/framework/OpParamMaker.h +++ b/torch_npu/csrc/framework/OpParamMaker.h @@ -21,6 +21,8 @@ #include "third_party/acl/inc/acl/acl_base.h" #include "torch_npu/csrc/framework/interface/AclOpCompileInterface.h" #include "torch_npu/csrc/framework/NPUDefine.h" +#include "torch_npu/csrc/framework/utils/NpuFuzzyBlacklist.h" +#include "torch_npu/csrc/framework/interface/EnvVariables.h" namespace at_npu { @@ -32,16 +34,16 @@ namespace at_npu class OpAttrMaker { public: - static void Set(aclopAttr *attr, string name, bool value); - static void Set(aclopAttr *attr, string name, int64_t value); - static void Set(aclopAttr *attr, string name, float value); - static void Set(aclopAttr *attr, string name, string value); - static void Set(aclopAttr *attr, string name, c10::IntArrayRef value); - static void Set(aclopAttr *attr, string name, at::ArrayRef<float> value); - static void Set(aclopAttr *attr, string name, c10::Scalar value); + static void Set(aclopAttr *attr, const string &name, bool value); + static void Set(aclopAttr *attr, const string &name, int64_t value); + static void Set(aclopAttr *attr, const string &name, float value); + static void Set(aclopAttr *attr, const string &name, string value); + static void Set(aclopAttr *attr, const string &name, c10::IntArrayRef value); + static
void Set(aclopAttr *attr, const string &name, at::ArrayRef<float> value); + static void Set(aclopAttr *attr, const string &name, c10::Scalar value); static void Set( aclopAttr *attr, - string name, + const string &name, at::ArrayRef<c10::IntArrayRef> value); }; // class OpAttrMaker @@ -242,7 +244,7 @@ const at::Tensor &hostTensor) { AddInput(desc, buffer, dim, format); - execParam.hostMem.emplace_back(hostTensor); + execParam.hostMem.emplace_back(hostTensor.storage()); } void AddConst(c10::SmallVector dimList) @@ -271,7 +273,7 @@ } template <typename dataType> - void AddAttr(string attrName, dataType value) + void AddAttr(const string& attrName, dataType value) { InitAttr(); AttrInfoMaker::Add(value, attrInfo); @@ -291,20 +293,49 @@ int inputNum = execParam.inDesc.size(); int outputNum = execParam.outDesc.size(); int constNum = execParam.constLists.size(); - const int64_t **constListArr = new const int64_t *[constNum]; - const aclTensorDesc **aclTensorInputDescArr = - new const aclTensorDesc *[inputNum]; - const aclTensorDesc **aclTensorOutputDescArr = - new const aclTensorDesc *[outputNum]; - const aclDataBuffer **aclDataInputBuffArr = - new const aclDataBuffer *[inputNum]; - aclDataBuffer **aclDataOutputBuffArr = new aclDataBuffer *[outputNum]; - - int64_t *constIdxArr = new int64_t[constNum]; - int64_t *inputDimsArr = new int64_t[inputNum]; - int64_t *outputDimsArr = new int64_t[outputNum]; - aclFormat *inputFormatsArr = new aclFormat[inputNum]; - aclFormat *outputFormatsArr = new aclFormat[outputNum]; + + size_t inputTensorDescArrLen = inputNum * sizeof(uintptr_t); + size_t inputDataBuffArrLen = inputNum * sizeof(uintptr_t); + size_t inputDimsArrLen = inputNum * sizeof(int64_t); + size_t inputFormatsArrLen = inputNum * sizeof(aclFormat); + + size_t outputTensorDescArrLen = outputNum * sizeof(uintptr_t); + size_t outputDataBuffArrLen = outputNum * sizeof(uintptr_t); + size_t outputDimsArrLen = outputNum * sizeof(int64_t); + size_t outputFormatsArrLen = outputNum * sizeof(aclFormat); + + size_t constListArrLen = constNum * sizeof(uintptr_t); + size_t constIdxArrLen = constNum * sizeof(int64_t); + + size_t totalMemLen = + inputTensorDescArrLen + inputDataBuffArrLen + + inputDimsArrLen + inputFormatsArrLen + + outputTensorDescArrLen + outputDataBuffArrLen + + outputDimsArrLen + outputFormatsArrLen + + constListArrLen + constIdxArrLen; + char* basePtr = static_cast<char *>(malloc(totalMemLen)); + AT_ASSERT(basePtr != nullptr); + const aclTensorDesc** aclTensorInputDescArr = reinterpret_cast<const aclTensorDesc **>(basePtr); + basePtr += inputTensorDescArrLen; + const aclDataBuffer** aclDataInputBuffArr = reinterpret_cast<const aclDataBuffer **>(basePtr); + basePtr += inputDataBuffArrLen; + int64_t* inputDimsArr = reinterpret_cast<int64_t *>(basePtr); + basePtr += inputDimsArrLen; + aclFormat* inputFormatsArr = reinterpret_cast<aclFormat *>(basePtr); + basePtr += inputFormatsArrLen; + + const aclTensorDesc** aclTensorOutputDescArr = reinterpret_cast<const aclTensorDesc **>(basePtr); + basePtr += outputTensorDescArrLen; + aclDataBuffer** aclDataOutputBuffArr = reinterpret_cast<aclDataBuffer **>(basePtr); + basePtr += outputDataBuffArrLen; + int64_t* outputDimsArr = reinterpret_cast<int64_t *>(basePtr); + basePtr += outputDimsArrLen; + aclFormat* outputFormatsArr = reinterpret_cast<aclFormat *>(basePtr); + basePtr += outputFormatsArrLen; + + const int64_t** constListArr = reinterpret_cast<const int64_t **>(basePtr); + basePtr += constListArrLen; + int64_t* constIdxArr = reinterpret_cast<int64_t *>(basePtr); std::copy( execParam.inDesc.begin(), @@ -366,7 +397,10 @@ params.constParams.constList = constListArr; params.constParams.constIdx =
constIdxArr; params.hostMemory = execParam.hostMem; - } + if (!FuzzyCompileBlacklist::GetInstance().IsInBlacklist(opName) && env::CheckFuzzyEnable()) { + params.isFuzzy = true; + } + } void Run(); @@ -437,7 +471,7 @@ c10::SmallVector outFormats; c10::SmallVector constLists; c10::SmallVector constIdxs; - c10::SmallVector hostMem; + c10::SmallVector hostMem; aclopAttr *attr = nullptr; bool hasAttr = false; diff --git a/torch_npu/csrc/framework/contiguous/broadcast_opt.cpp b/torch_npu/csrc/framework/contiguous/broadcast_opt.cpp index cb919c2227..5e48d062bf 100644 --- a/torch_npu/csrc/framework/contiguous/broadcast_opt.cpp +++ b/torch_npu/csrc/framework/contiguous/broadcast_opt.cpp @@ -14,7 +14,7 @@ // limitations under the License. #include - +#include #include "torch_npu/csrc/framework/contiguous/ContiguousOpt.h" namespace at_npu { @@ -95,11 +95,10 @@ private: temp_src.unsafeGetTensorImpl()->set_sizes_and_strides(src_size, src.strides()); - c10::npu::NPUStream copy_stream = c10::npu::getCurrentNPUStream(); if (temp_src.is_contiguous()) { auto temp_dst = NPUNativeFunctions::npu_broadcast(temp_src, self.sizes()); - aclrtMemcpyAsync(self.data_ptr(), self.nbytes(), temp_dst.data_ptr(), - self.nbytes(), ACL_MEMCPY_DEVICE_TO_DEVICE, copy_stream); + c10::npu::queue::LaunchAsyncCopyTask(self.data_ptr(), self.nbytes(), temp_dst.data_ptr(), + self.nbytes(), ACL_MEMCPY_DEVICE_TO_DEVICE); return true; } return false; diff --git a/torch_npu/csrc/framework/utils/CalcuOpUtil.cpp b/torch_npu/csrc/framework/utils/CalcuOpUtil.cpp index 88d01a2293..ee303fe3cb 100644 --- a/torch_npu/csrc/framework/utils/CalcuOpUtil.cpp +++ b/torch_npu/csrc/framework/utils/CalcuOpUtil.cpp @@ -27,6 +27,7 @@ #include "torch_npu/csrc/framework/utils/NpuFuzzyBlacklist.h" #include "torch_npu/csrc/framework/interface/EnvVariables.h" #include "third_party/acl/inc/acl/acl_base.h" +#include "c10/npu/interface/AsyncTaskQueueInterface.h" namespace at_npu { @@ -182,8 +183,8 @@ size_t src_size, aclrtMemcpyKind kind) { - AT_NPU_CHECK(aclrtMemcpyAsync( - dst, dst_size, src, src_size, kind, c10::npu::getCurrentNPUStream())); + AT_NPU_CHECK(c10::npu::queue::LaunchAsyncCopyTask( + dst, dst_size, const_cast<void *>(src), src_size, kind)); return SUCCESS; } @@ -378,20 +379,40 @@ int inputNum = input.size(); int outputNum = output.size(); - const aclTensorDesc **aclTensorInputDescArr = - inputNum == 0 ? nullptr : new const aclTensorDesc *[inputNum]; - const aclTensorDesc **aclTensorOutputDescArr = - outputNum == 0 ? nullptr : new const aclTensorDesc *[outputNum]; - - const aclDataBuffer **aclDataInputBuffArr = - inputNum == 0 ? nullptr : new const aclDataBuffer *[inputNum]; - aclDataBuffer **aclDataOutputBuffArr = - outputNum == 0 ?
nullptr : new aclDataBuffer *[outputNum]; - - int64_t *inputDimsArr = new int64_t[inputNum]; - int64_t *outputDimsArr = new int64_t[outputNum]; - aclFormat *inputFormatsArr = new aclFormat[inputNum]; - aclFormat *outputFormatsArr = new aclFormat[outputNum]; + size_t inputTensorDescArrLen = inputNum * sizeof(uintptr_t); + size_t inputDataBuffArrLen = inputNum * sizeof(uintptr_t); + size_t inputDimsArrLen = inputNum * sizeof(int64_t); + size_t inputFormatsArrLen = inputNum * sizeof(aclFormat); + + size_t outputTensorDescArrLen = outputNum * sizeof(uintptr_t); + size_t outputDataBuffArrLen = outputNum * sizeof(uintptr_t); + size_t outputDimsArrLen = outputNum * sizeof(int64_t); + size_t outputFormatsArrLen = outputNum * sizeof(aclFormat); + + size_t totalMemLen = + inputTensorDescArrLen + inputDataBuffArrLen + + inputDimsArrLen + inputFormatsArrLen + + outputTensorDescArrLen + outputDataBuffArrLen + + outputDimsArrLen + outputFormatsArrLen; + char* basePtr = static_cast<char *>(malloc(totalMemLen)); + AT_ASSERT(basePtr != nullptr); + + const aclTensorDesc** aclTensorInputDescArr = reinterpret_cast<const aclTensorDesc **>(basePtr); + basePtr += inputTensorDescArrLen; + const aclDataBuffer** aclDataInputBuffArr = reinterpret_cast<const aclDataBuffer **>(basePtr); + basePtr += inputDataBuffArrLen; + int64_t* inputDimsArr = reinterpret_cast<int64_t *>(basePtr); + basePtr += inputDimsArrLen; + aclFormat* inputFormatsArr = reinterpret_cast<aclFormat *>(basePtr); + basePtr += inputFormatsArrLen; + + const aclTensorDesc** aclTensorOutputDescArr = reinterpret_cast<const aclTensorDesc **>(basePtr); + basePtr += outputTensorDescArrLen; + aclDataBuffer** aclDataOutputBuffArr = reinterpret_cast<aclDataBuffer **>(basePtr); + basePtr += outputDataBuffArrLen; + int64_t* outputDimsArr = reinterpret_cast<int64_t *>(basePtr); + basePtr += outputDimsArrLen; + aclFormat* outputFormatsArr = reinterpret_cast<aclFormat *>(basePtr); for (int i = 0; i < inputNum; i++) { @@ -697,7 +718,13 @@ auto attrRes = CalcuOpUtil::CreateNpuAttrDesc(attrs); cur_paras.attr = std::get<0>(attrRes); cur_paras.attrInfo = std::get<1>(attrRes); - c10::npu::enCurrentNPUStream(&cur_paras); + if (!FuzzyCompileBlacklist::GetInstance().IsInBlacklist(cur_paras.opType) && env::CheckFuzzyEnable()) { + cur_paras.isFuzzy = true; + } + c10::npu::queue::QueueParas params(c10::npu::queue::COMPILE_AND_EXECUTE, sizeof(ExecuteParas), &cur_paras); + c10::SmallVector needClearVec; + c10::npu::enCurrentNPUStream(&params, needClearVec); + needClearVec.clear(); return; } diff --git a/torch_npu/csrc/framework/utils/NpuUtils.cpp b/torch_npu/csrc/framework/utils/NpuUtils.cpp index f1e7ea205a..e41a123f11 100644 --- a/torch_npu/csrc/framework/utils/NpuUtils.cpp +++ b/torch_npu/csrc/framework/utils/NpuUtils.cpp @@ -19,6 +19,7 @@ #include "torch_npu/csrc/framework/utils/NpuUtils.h" #include "torch_npu/csrc/framework/utils/CalcuOpUtil.h" +#include "c10/npu/interface/AsyncTaskQueueInterface.h" #include "torch_npu/csrc/framework/FormatHelper.h" #include "torch_npu/csrc/framework/StorageDescHelper.h" #include "torch_npu/csrc/framework/utils/KernelNpuOutputSize.h" @@ -210,15 +211,13 @@ { auto src_desc = src.storage().unsafeGetStorageImpl()->npu_desc_; at::Tensor src_new = OpPreparation::ApplyTensorWithFormat(src_desc.base_sizes_, src.options(), ACL_FORMAT_NC1HWC0); - c10::npu::NPUStream copy_stream = c10::npu::getCurrentNPUStream(); int64_t numel = src_new.numel(); - aclError error = aclrtMemcpyAsync( + aclError error = c10::npu::queue::LaunchAsyncCopyTask( src_new.data_ptr(), numel * src_new.element_size(), (uint8_t *)src.data_ptr() - src.storage_offset() * src.element_size(), numel
* src.element_size(), - ACL_MEMCPY_DEVICE_TO_DEVICE, - copy_stream); + ACL_MEMCPY_DEVICE_TO_DEVICE); src_new.set_(src_new.storage(), src.storage_offset(), src.sizes(), src.strides()); src_new.storage().unsafeGetStorageImpl()->npu_desc_.npu_format_ = ACL_FORMAT_NCHW; -- Gitee
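
Note on the LaunchAsyncCopyTask migration (CopyKernel.cpp, CopyKernelNpu.cpp, CopyMemoryKernel.cpp, ResizeNpu.h, broadcast_opt.cpp, CalcuOpUtil.cpp, NpuUtils.cpp): every direct aclrtMemcpyAsync issued on the pytorch thread is replaced by an enqueue onto the task queue, so the copy is launched by the acl thread in submission order with the other queued work. The real types and signatures live in c10/npu/interface/AsyncTaskQueueInterface.h, which this diff does not show; the sketch below only illustrates the enqueue/replay pattern with simplified stand-in types (field names follow MemcopyAsyncFunc above; the real queue, its error handling, and the pinned-memory variant are omitted).

// Sketch only: stand-in types for the enqueue/replay pattern. The real
// definitions are in c10/npu/interface/AsyncTaskQueueInterface.h.
#include <cstddef>
#include <deque>
#include <mutex>

using aclrtStream = void*;
using aclrtMemcpyKind = int;

struct CopyParas {        // mirrors the fields read by MemcopyAsyncFunc
  void* dst; size_t dstLen;
  void* src; size_t srcLen;
  aclrtMemcpyKind kind;
};

struct QueueParas {       // header consumed by AsncExecFunc
  int paramType;          // e.g. ASYNC_MEMCPY
  aclrtStream paramStream;
  size_t paramLen;
  void* paramVal;         // points at the payload (CopyParas here)
};

std::mutex gMutex;
std::deque<QueueParas> gQueue;  // the real queue is a preallocated ring of
                                // slots (see NewFunc's malloc(size * caption))

// Producer (pytorch thread): pack the request and return immediately.
int LaunchAsyncCopyTaskSketch(void* dst, size_t dstLen, void* src,
                              size_t srcLen, aclrtMemcpyKind kind,
                              aclrtStream stream) {
  auto* copy = new CopyParas{dst, dstLen, src, srcLen, kind};
  std::lock_guard<std::mutex> lock(gMutex);
  gQueue.push_back(QueueParas{2 /*ASYNC_MEMCPY*/, stream, sizeof(CopyParas), copy});
  return 0;  // any aclrtMemcpyAsync error now surfaces on the acl thread
}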
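
Note on the recorded_events / complete_events sets: with TASK_QUEUE_ENABLE set, aclrtRecordEvent runs on the acl thread (RecordEventFunc), while the allocators poll events from the pytorch thread. An ACL event must not be queried or destroyed before it has actually been recorded, which is exactly what the mutex-guarded sets prevent: process_events/processEvents break out and retry later for any event the acl thread has not yet marked as recorded. Condensed view of the handshake, using the patch's names but none of the surrounding allocator logic:

#include <mutex>
#include <set>

using aclrtEvent = void*;  // stand-in for the ACL handle type

std::mutex recorded_event_mutex;
std::set<aclrtEvent> recorded_events;  // events already recorded by the acl thread

// acl thread, right after aclrtRecordEvent succeeds (see RecordEventFunc):
void insertRecordedEvent(aclrtEvent event) {
  std::lock_guard<std::mutex> lock(recorded_event_mutex);
  recorded_events.insert(event);
}

// pytorch thread, before aclrtQueryEvent/aclrtDestroyEvent (see process_events):
bool eventSafeToTouch(aclrtEvent event) {
  std::lock_guard<std::mutex> lock(recorded_event_mutex);
  // not yet recorded: querying it now would race with the acl thread,
  // so the caller breaks out of its polling loop and retries next pass
  return recorded_events.find(event) != recorded_events.end();
}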
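
Note on NewFunc/CopyFunc: each queue slot is one malloc'd block sized as sizeof(QueueParas) plus the largest of the three payload structs (GetMaxLen), and slots are reused as the ring wraps. Because ExecuteParas and CopyParas hold non-trivial members (std::string, c10::SmallVector), CopyFunc runs the old payload's destructor by hand before the memset. The sketch below shows that reuse pattern with a toy payload; the patch then refills the slot via Copy() rather than the placement-new used here, but the destroy-before-reuse requirement is the same.

#include <cstdlib>
#include <cstring>
#include <new>
#include <string>

struct Payload {                 // toy stand-in for ExecuteParas/CopyParas
  std::string opType;
};

int main() {
  void* slot = std::malloc(sizeof(Payload));   // one ring-buffer slot
  Payload* p = new (slot) Payload{"first"};    // construct a task in place

  // Slot reuse (what CopyFunc does when the ring wraps around):
  p->~Payload();                               // destroy the old payload...
  std::memset(slot, 0, sizeof(Payload));       // ...so the memset is safe
  p = new (slot) Payload{"second"};            // construct the next task

  p->~Payload();
  std::free(slot);
  return 0;
}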
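
Note on the ExportParams/CreateAclTensorDesc rewrite: roughly a dozen per-op new[] allocations are collapsed into a single malloc that is carved into typed sub-arrays, and DestroyAclParams now issues one free on the base pointer (params.input_desc) while merely nulling the other views. Reduced sketch of the carving with two element sizes; one general caution, not a claim about this patch: once 4-byte aclFormat arrays are interleaved with 8-byte arrays, the alignment of later sub-arrays depends on the element counts, so layout order matters.

#include <cassert>
#include <cstdint>
#include <cstdlib>

int main() {
  int inputNum = 3;

  size_t descLen = inputNum * sizeof(uintptr_t);  // pointer-sized entries first
  size_t dimsLen = inputNum * sizeof(int64_t);
  size_t fmtLen  = inputNum * sizeof(int32_t);    // 4-byte entries last

  char* basePtr = static_cast<char*>(std::malloc(descLen + dimsLen + fmtLen));
  assert(basePtr != nullptr);

  // Carve the block into typed views, advancing by each sub-array's length.
  void** descArr   = reinterpret_cast<void**>(basePtr);
  int64_t* dimsArr = reinterpret_cast<int64_t*>(basePtr + descLen);
  int32_t* fmtArr  = reinterpret_cast<int32_t*>(basePtr + descLen + dimsLen);

  descArr[0] = nullptr; dimsArr[0] = 4; fmtArr[0] = 0;  // use as plain arrays

  // Release everything through the base pointer; the other views are aliases
  // and must only be nulled, exactly as DestroyAclParams does.
  std::free(basePtr);
  return 0;
}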
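
Note on needClearVec (HostAllocator::free, CopyFunc, OpCommandBase::Run): per the comment in CopyFunc, freeing pinned memory can itself enqueue an aclrtRecordEvent task, so releasing it while holding the allocator lock or from inside the enqueue path could self-deadlock. The patch therefore swaps the storages out under the lock and lets them be destroyed after it is released. Minimal shape of that pattern, with a stand-in payload type:

#include <mutex>
#include <vector>

std::mutex mu;
std::vector<int> guardedStorages;  // stand-in for the pinned-memory Storages

void freeWithDeferredRelease() {
  std::vector<int> toRelease;
  {
    std::lock_guard<std::mutex> lock(mu);
    toRelease.swap(guardedStorages);  // detach under the lock...
  }
  toRelease.clear();  // ...destroy after unlocking; real destructors may
                      // re-enter the queue/lock, which is safe out here
}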