diff --git a/torch_npu/csrc/core/npu/NPUCachingAllocator.cpp b/torch_npu/csrc/core/npu/NPUCachingAllocator.cpp index 56c936b8579bf81ddfba1fb58cbe3e25c2e2a59e..5c854c5163eff22f36d3697aa1c5b991731e7404 100644 --- a/torch_npu/csrc/core/npu/NPUCachingAllocator.cpp +++ b/torch_npu/csrc/core/npu/NPUCachingAllocator.cpp @@ -2845,6 +2845,8 @@ private: void insert_events(Block *block) { + int pre_device = -1; + NPU_CHECK_ERROR(c10_npu::GetDevice(&pre_device)); aclrtContext compiler_ctx = aclrtContext(); aclError ret_ctx = aclrtGetCurrentContext(&compiler_ctx); NPU_CHECK_ERROR(aclrtSetCurrentContext(c10_npu::GetDeviceContext(block->device))); @@ -2862,7 +2864,9 @@ private: npu_events[stream].emplace_back(std::move(event), block); } if (ret_ctx == ACL_ERROR_NONE) { - NPU_CHECK_ERROR(aclrtSetCurrentContext(compiler_ctx)); + NPU_CHECK_ERROR(aclrtSetCurrentContext(compiler_ctx)); + // Setting context will exchange device implicitly, so we need to reset the cached device here to ensure consistency. + NPU_CHECK_ERROR(c10_npu::SetDevice(pre_device)); } } diff --git a/torch_npu/csrc/core/npu/NPUQueue.cpp b/torch_npu/csrc/core/npu/NPUQueue.cpp index 579514ab37390f36aa208e7711c6fcec131a9f98..525537e349431a6541f945cda8ba6607c071baa6 100644 --- a/torch_npu/csrc/core/npu/NPUQueue.cpp +++ b/torch_npu/csrc/core/npu/NPUQueue.cpp @@ -5,6 +5,7 @@ #include "torch_npu/csrc/framework/utils/NpuUtils.h" #include "torch_npu/csrc/core/npu/NPUFunctions.h" #include "torch_npu/csrc/framework/OpParamMaker.h" +#include "torch_npu/csrc/framework/OpCommand.h" #include "torch_npu/csrc/core/npu/register/OptionsManager.h" #include "torch_npu/csrc/core/npu/NPUEventManager.h" @@ -249,7 +250,7 @@ NPUStatus Repository::MakeSureQueueEmpty(bool check_error) // occur. #ifndef BUILD_LIBTORCH PyThreadState *gilState = nullptr; - if (PyGILState_Check() != 0) { + if (PyGILState_Check() != 0 && g_used_aclop) { gilState = PyEval_SaveThread(); } #endif @@ -531,7 +532,14 @@ void Repository::Enqueue(void *cur_paras) if (IsFullQueue()) { #ifndef BUILD_LIBTORCH // double check the current thread hold a Gil lock - if (PyGILState_Check() != 0) { + // and release the GIL to TE op compiler in case the acl thread deadlock. + // However, this operator could produce another form of deadlock. + // When thread A deconstract a tensor, it will hold the mutex of deviceCachingAllocator and insert an event into the taskqueue. + // If the taskqueue is full, thead A will run into here and release the GIL. + // Once another thread B get GIL and trigger GC, it may deconstract another tensor + // and try to get deviceCachingAllocator's mutex, which would cause another form of deadlock. + // Since the aclop will be deprecated soon, we just add a using-aclop check here to aviod the second case of deadlock. + if (PyGILState_Check() != 0 && g_used_aclop) { Py_BEGIN_ALLOW_THREADS s = eventfd_read(efd_write, &u); Py_END_ALLOW_THREADS } else { diff --git a/torch_npu/csrc/framework/OpCommand.cpp b/torch_npu/csrc/framework/OpCommand.cpp index 6b98651c51dba728c9062a47d777650ae7ac93a6..59022f09e63aeaf76509362c29978fbd6f2a0a30 100644 --- a/torch_npu/csrc/framework/OpCommand.cpp +++ b/torch_npu/csrc/framework/OpCommand.cpp @@ -33,6 +33,8 @@ static std::unordered_map> integral_limits_map {at::ScalarType::Short, {std::numeric_limits::max(), std::numeric_limits::min()}}}; } // namespace +std::atomic g_used_aclop{false}; + namespace at_npu { namespace native { @@ -124,6 +126,7 @@ void OpCommand::Run() { // Check for npu graph if (aclCmd->CheckCustomHandlerNull()) { + g_used_aclop = true; c10_npu::assertNotCapturingAclop(aclCmd->GetName()); } diff --git a/torch_npu/csrc/framework/OpCommand.h b/torch_npu/csrc/framework/OpCommand.h index e60617077c976b0109a37b72c33254a25333a095..f30d9fb4988bfc8e32902b8b2fab783f820b4d54 100644 --- a/torch_npu/csrc/framework/OpCommand.h +++ b/torch_npu/csrc/framework/OpCommand.h @@ -10,6 +10,8 @@ #include "torch_npu/csrc/framework/utils/NpuUtils.h" #include "torch_npu/csrc/framework/utils/NPUDefinition.h" +extern std::atomic g_used_aclop; + namespace at_npu { namespace native { diff --git a/version.txt b/version.txt index 274800e8c4a62f527d1a1f243774c7cec54d8bf0..786a565423575041b6c85b10895ff602567a8f56 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -2.5.1.post2 +2.5.1.post3