diff --git a/torch_npu/__init__.py b/torch_npu/__init__.py
index 64554f108144236a86ef0d37ca6bbab44f4da331..63da6c4dcd5bb60eb1854796112766e4b4c5e1f9 100644
--- a/torch_npu/__init__.py
+++ b/torch_npu/__init__.py
@@ -98,6 +98,11 @@ apply_class_patches()
 
 # NPU exit, need to synchronize devices
 def _npu_shutdown():
+    if torch.npu.is_available() and \
+            torch.npu.is_initialized() and \
+            torch.distributed.is_available() and \
+            torch.distributed.is_initialized():
+        torch.distributed.release_process_group()
     torch_npu._C._npu_shutdown()
 
 
diff --git a/torch_npu/csrc/distributed/Init.cpp b/torch_npu/csrc/distributed/Init.cpp
index e8717648c5ce728b8125839a61240e31588e3ac2..2f16f54fc90a44d9f65f5b5840d22406eb168d79 100644
--- a/torch_npu/csrc/distributed/Init.cpp
+++ b/torch_npu/csrc/distributed/Init.cpp
@@ -299,7 +299,12 @@ PyObject* c10d_init(PyObject* _unused, PyObject* noargs) {
           py::arg("rank"),
           py::arg("size"),
           py::arg("timeout") = std::chrono::milliseconds(
-              ::c10d_npu::ProcessGroupHCCL::kProcessGroupHCCLOpTimeoutMillis));
+              ::c10d_npu::ProcessGroupHCCL::kProcessGroupHCCLOpTimeoutMillis))
+      .def("release_resource",
+           [](::c10d_npu::ProcessGroupHCCL& pg) {
+             pg.release_resource();
+           },
+           py::call_guard<py::gil_scoped_release>());
 
   intrusive_ptr_class_<::c10d_npu::ProcessGroupHCCL::Options>(
       processGroupHCCL, "Options")
diff --git a/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp b/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp
index 31d394afccb8252302d53b71d739288b85f642c6..a4ae31b4f5a11769683d6c8a8a022f065cbb5897 100644
--- a/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp
+++ b/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp
@@ -723,4 +723,10 @@ c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupHCCL::recvAnysource(
     int /* unused */) {
   throw std::runtime_error("ProcessGroupHCCL does not support recv");
 }
+
+void ProcessGroupHCCL::release_resource() {
+  c10::npu::npuSynchronizeDevice();
+  this->hcclEvents_.clear();
+  this->devHCCLCommMap_.clear();
+}
 } // namespace c10d_npu
diff --git a/torch_npu/csrc/distributed/ProcessGroupHCCL.hpp b/torch_npu/csrc/distributed/ProcessGroupHCCL.hpp
index 13b1da5eb7b5dabcf2e709a75e30b670ebf2324c..720dcfba91c3bf038e120a359144eb50a445e8c5 100644
--- a/torch_npu/csrc/distributed/ProcessGroupHCCL.hpp
+++ b/torch_npu/csrc/distributed/ProcessGroupHCCL.hpp
@@ -248,6 +248,8 @@ public:
 
   static const int64_t kProcessGroupHCCLOpTimeoutMillis;
 
+  void release_resource();
+
 protected:
   // Helper that broadcasts HCCL Master ID to all ranks through the store
   void broadcastMasterID(HcclRootInfo* hcclID);
diff --git a/torch_npu/distributed/distributed_c10d.py b/torch_npu/distributed/distributed_c10d.py
index 7eb05691fa7aee4925390ebd6296598409e52e6e..7b7d713dbe4a07e3c48172c5ee59408c24fa2c82 100644
--- a/torch_npu/distributed/distributed_c10d.py
+++ b/torch_npu/distributed/distributed_c10d.py
@@ -72,7 +72,7 @@ __all__ = [
     "isend", "irecv", "send", "recv", "P2POp", "batch_isend_irecv", "broadcast",
     "all_reduce", "all_reduce_coalesced", "reduce", "all_gather", "all_gather_coalesced",
     "gather", "scatter", "reduce_scatter", "all_to_all_single", "all_to_all", "barrier", "new_group", "ProcessGroupHCCL",
-    "_get_default_group"
+    "_get_default_group", "release_process_group"
 ]
 
 # Some reduce ops are not supported by complex numbers and will result in an error.
@@ -1827,3 +1827,7 @@ def new_group(ranks=None, timeout=default_pg_timeout, backend=None):
         _store_based_barrier(global_rank, default_store, timeout)
 
     return pg
+
+def release_process_group():
+    if _default_pg is not None and is_hccl_available():
+        _default_pg.release_resource()
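
Usage sketch (not part of the diff): with this change, _npu_shutdown() releases the default HCCL process group automatically at interpreter exit, and release_process_group() is also exported so a script can trigger the same cleanup explicitly. The Python below is a minimal illustration under assumptions: it presumes the usual torch_npu setup where importing torch_npu registers the "hccl" backend, adds the torch.npu namespace and Tensor.npu(), and makes release_process_group reachable through torch.distributed (as the shutdown hook above relies on); none of that setup appears in this diff, and launch-time details such as MASTER_ADDR/RANK/WORLD_SIZE are omitted.

# Minimal sketch, assuming a torchrun-style launch and that importing torch_npu
# registers the "hccl" backend and patches release_process_group onto
# torch.distributed; these setup details are assumptions, not shown in the diff.
import torch
import torch.distributed as dist
import torch_npu

def main():
    dist.init_process_group(backend="hccl")   # assumed typical HCCL initialization
    torch.npu.set_device(dist.get_rank())

    t = torch.ones(4).npu()
    dist.all_reduce(t)                        # some collective work on the NPU

    # Explicit cleanup: synchronizes the device and drops cached HCCL events and
    # communicators, the same path _npu_shutdown() now takes automatically at exit.
    dist.release_process_group()

if __name__ == "__main__":
    main()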