diff --git a/torch_npu/csrc/core/npu/NPUFunctions.cpp b/torch_npu/csrc/core/npu/NPUFunctions.cpp index d9c1944978fdc66045f11b5124cf7bec1e2818ff..d523e945ebfb19b34f2ab0c8b2a5541b4fa03eef 100644 --- a/torch_npu/csrc/core/npu/NPUFunctions.cpp +++ b/torch_npu/csrc/core/npu/NPUFunctions.cpp @@ -46,7 +46,7 @@ aclError GetDevice(int32_t *device) *device = local_device; return ACL_ERROR_NONE; } - aclError err = aclrtGetDevice(device); + aclError err = aclrtGetDevice(device); if (err != ACL_ERROR_NONE) { CHECK_AND_THROW_ERROR_WITH_SPECIFIC_MESSAGE(err); } @@ -180,9 +180,13 @@ void device_synchronize() int ExchangeDevice(int device) { + int32_t cur_device = 0; + NPU_CHECK_ERROR_WITHOUT_UCE(GetDevice(&cur_device)); + if (device == cur_device) { + return cur_device; + } NPU_CHECK_ERROR_WITHOUT_UCE(SetDevice(device)); - - return device; + return cur_device; } bool IsContextInitialized() diff --git a/torch_npu/csrc/core/npu/impl/NPUGuardImpl.h b/torch_npu/csrc/core/npu/impl/NPUGuardImpl.h index cd8a7bebabb5b146293f469c067190428f8ff8ea..66b9b45ab202476eb915f1bc96c609b0007e139a 100644 --- a/torch_npu/csrc/core/npu/impl/NPUGuardImpl.h +++ b/torch_npu/csrc/core/npu/impl/NPUGuardImpl.h @@ -38,13 +38,10 @@ struct NPUGuardImpl final : public c10::impl::DeviceGuardImplInterface { } c10::Device exchangeDevice(c10::Device d) const override { - TORCH_INTERNAL_ASSERT(d.type() == c10::DeviceType::PrivateUse1, + TORCH_INTERNAL_ASSERT(d.is_privateuseone(), "DeviceType must be NPU. Actual DeviceType is: ", d.type(), PTA_ERROR(ErrCode::PARAM)); - c10::Device old_device = getDevice(); - if (old_device.index() != d.index()) { - NPU_CHECK_ERROR_WITHOUT_UCE(c10_npu::SetDevice(d.index())); - } - return old_device; + auto old_device_index = c10_npu::ExchangeDevice(d.index()); + return c10::Device(c10::DeviceType::PrivateUse1, old_device_index); } c10::Device getDevice() const override { diff --git a/torch_npu/csrc/npu/Module.cpp b/torch_npu/csrc/npu/Module.cpp index f4d7a341f8b551f0c4f7d2d4d35fe9916c9076b6..6f3c2c10f659d76d8f17a4a218e711078eb9f311 100644 --- a/torch_npu/csrc/npu/Module.cpp +++ b/torch_npu/csrc/npu/Module.cpp @@ -353,6 +353,22 @@ PyObject* THNPModule_stopDevice_wrap(PyObject* self, PyObject* arg) END_HANDLE_TH_ERRORS } +PyObject* THNPModule_exchangeDevice_wrap(PyObject* self, PyObject* arg) +{ + HANDLE_TH_ERRORS + TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to exchangeDevice"); + auto device_index = THPUtils_unpackInt(arg); + if (device_index < 0) { + return THPUtils_packInt32(-1); + } + + torch_npu::utils::npu_lazy_init(); + auto current_device = c10_npu::ExchangeDevice(device_index); + + return THPUtils_packInt64(current_device); + END_HANDLE_TH_ERRORS +} + PyObject* THNPModule_check_uce_in_memory_wrap(PyObject* self, PyObject* arg) { HANDLE_TH_ERRORS @@ -1339,6 +1355,7 @@ static struct PyMethodDef THNPModule_methods[] = { {"_npu_synchronize", (PyCFunction)THNPModule_npuSynchronize, METH_NOARGS, nullptr}, {"_npu_setDevice", (PyCFunction)THNPModule_setDevice_wrap, METH_O, nullptr}, {"_npu_getDevice", (PyCFunction)THNPModule_getDevice_wrap, METH_NOARGS, nullptr}, + {"_npu_exchangeDevice", (PyCFunction)THNPModule_exchangeDevice_wrap, METH_O, nullptr}, {"_npu_stopDevice", (PyCFunction)THNPModule_stopDevice_wrap, METH_O, nullptr}, {"_npu_restart_device", (PyCFunction)THNPModule_restart_device_wrap, METH_O, nullptr}, {"_npu_check_uce_in_memory", (PyCFunction)THNPModule_check_uce_in_memory_wrap, METH_O, nullptr}, diff --git a/torch_npu/npu/__init__.py b/torch_npu/npu/__init__.py index 12d67a0bb8e88c2710b6d2124621aaeca423a2d4..732ec99bfd84e44e75083639531b7636d16b0e31 100644 --- a/torch_npu/npu/__init__.py +++ b/torch_npu/npu/__init__.py @@ -14,6 +14,7 @@ __all__ = [ "is_available", "device", "device_of", + "_DeviceGuard", "stream", "set_stream", "current_stream", @@ -125,7 +126,8 @@ import torch_npu from torch_npu.utils._error_code import ErrCode, pta_error, prof_error from .utils import (synchronize, device_count, can_device_access_peer, set_device, current_device, get_device_name, get_device_properties, get_device_capability, _get_device_index, - device, device_of, stream, set_stream, current_stream, default_stream, set_sync_debug_mode, + device, device_of, _DeviceGuard, + stream, set_stream, current_stream, default_stream, set_sync_debug_mode, get_sync_debug_mode, init_dump, current_blas_handle, is_bf16_supported, utilization, finalize_dump, set_dump, get_npu_overflow_flag, clear_npu_overflow_flag, mem_get_info, check_uce_in_memory, stress_detect) diff --git a/torch_npu/npu/utils.py b/torch_npu/npu/utils.py index 3f3b8493bfbffb52888bb2c8ef0cb6d4d70b543a..78d388d5065d743bf977b46d5b5bcfd598ef5041 100644 --- a/torch_npu/npu/utils.py +++ b/torch_npu/npu/utils.py @@ -132,6 +132,19 @@ def utilization(device=None): return torch_npu._C._npu_getDeviceUtilizationRate(device_id) +class _DeviceGuard(): + def __init__(self, index: int): + self.idx = index + self.prev_idx = -1 + + def __enter__(self): + self.prev_idx = torch_npu._C._exchangeDevice(self.idx) + + def __exit__(self, type: Any, value: Any, traceback: Any): + self.idx = torch_npu._C._exchangeDevice(self.prev_idx) + return False + + class device(object): r"""Context-manager that changes the selected device.