From b33fae41a0ff81ae26683be5752aaec7856ed4b9 Mon Sep 17 00:00:00 2001 From: huangyunlong Date: Sat, 23 Aug 2025 18:47:14 +0800 Subject: [PATCH] avoid setdevice 0 when getstream --- torch_npu/csrc/core/npu/NPUStream.cpp | 42 ++++++++++++++------------- 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/torch_npu/csrc/core/npu/NPUStream.cpp b/torch_npu/csrc/core/npu/NPUStream.cpp index c8a3a56ebc4..3e6bc1458b0 100644 --- a/torch_npu/csrc/core/npu/NPUStream.cpp +++ b/torch_npu/csrc/core/npu/NPUStream.cpp @@ -177,7 +177,7 @@ static c10::StreamId NPUStream_getStreamId(const LeakyStreamInternals* ptr) " (something has gone horribly wrong!)", PTA_ERROR(ErrCode::PTR)); } -static void initGlobalStreamState() +static void initGlobalStreamState(c10::DeviceIndex device_id) { num_npus = c10_npu::device_count(); // Check if the number of GPUs matches the expected compile-time max number @@ -189,11 +189,8 @@ static void initGlobalStreamState() C10_COMPILE_TIME_MAX_NPUS, "). Increase that and recompile.", PTA_ERROR(ErrCode::VALUE)); - int device_id = 0; - auto ret = c10_npu::GetDevice(&device_id); - if (ret != ACL_ERROR_NONE) { - ASCEND_LOGE("Device has not been set"); - } + NPUGuard device_guard{device_id}; + LazySetDevice(device_id); // Initializes default streams default_streams[device_id].device_index = device_id; npu_counters[device_id] = 0; @@ -215,6 +212,7 @@ static void initDeviceStreamState(c10::DeviceIndex device_index) // Switches to the requested device so streams are properly associated // with it. NPUGuard device_guard{device_index}; + LazySetDevice(device_index); static int StreamsPerPool = GetStreamsPerPool(); for (auto i = decltype(StreamsPerPool){0}; i < StreamsPerPool; ++i) { auto& npu_streami = npu_streams[device_index][i]; @@ -226,18 +224,20 @@ static void initDeviceStreamState(c10::DeviceIndex device_index) } } -static void initNPUStreamsOnce() +static void initNPUStreamsOnce(c10::DeviceIndex device_index = -1) { // Inits default and secondary streams (once, globally) - c10::DeviceIndex device_index = current_device(); - // makesure on real devcie - SetTargetDevice(); - LazySetDevice(device_index); + if (device_index == -1) { + c10::DeviceIndex device_index = current_device(); + // makesure on real devcie + SetTargetDevice(); + LazySetDevice(device_index); + } c10_npu::NpuSysCtrl::GetInstance().LazyInitialize(); if (!initialize_flag[device_index]) { std::lock_guard lock(mtx[device_index]); if (!initialize_flag[device_index]) { - initGlobalStreamState(); + initGlobalStreamState(device_index); initialize_flag[device_index] = true; } } @@ -353,7 +353,7 @@ aclrtStream NPUStream::stream() const NPUStream getNPUStreamFromPool(c10::DeviceIndex device_index) { - initNPUStreamsOnce(); + initNPUStreamsOnce(device_index); if (device_index == -1) { device_index = current_device(); } @@ -369,7 +369,7 @@ NPUStream getNPUStreamFromPool(c10::DeviceIndex device_index) NPUStream getStreamFromPool(const bool isHighPriority, c10::DeviceIndex device_index) { - initNPUStreamsOnce(); + initNPUStreamsOnce(device_index); if (device_index == -1) { device_index = current_device(); } @@ -390,7 +390,7 @@ NPUStream getStreamFromPool(const bool isHighPriority, c10::DeviceIndex device_i NPUStream getDefaultNPUStream(c10::DeviceIndex device_index) { - initNPUStreamsOnce(); + initNPUStreamsOnce(device_index); if (device_index == -1) { device_index = current_device(); } @@ -399,7 +399,7 @@ NPUStream getDefaultNPUStream(c10::DeviceIndex device_index) NPUStream getCurrentNPUStream(c10::DeviceIndex device_index) { - initNPUStreamsOnce(); + initNPUStreamsOnce(device_index); if (device_index == -1) { device_index = current_device(); } @@ -409,7 +409,7 @@ NPUStream getCurrentNPUStream(c10::DeviceIndex device_index) NPUStream getCurrentSecondaryStream(c10::DeviceIndex device_index) { - initNPUStreamsOnce(); + initNPUStreamsOnce(device_index); if (device_index == -1) { device_index = current_device(); } @@ -419,7 +419,7 @@ NPUStream getCurrentSecondaryStream(c10::DeviceIndex device_index) aclrtStream getCurrentNPUStreamNoWait(c10::DeviceIndex device_index) { - initNPUStreamsOnce(); + initNPUStreamsOnce(device_index); if (device_index == -1) { device_index = current_device(); } @@ -534,7 +534,7 @@ bool npuSynchronizeUsedDevices(bool check_error) void enCurrentNPUStream(void* cur_paras, c10::DeviceIndex device_index) { - initNPUStreamsOnce(); + initNPUStreamsOnce(device_index); if (device_index == -1) { device_index = current_device(); } @@ -617,6 +617,7 @@ void recovery_all_npu_streams(c10::DeviceIndex device_index) return; } NPUGuard device_guard{device_index}; + LazySetDevice(device_index); auto& default_streamsi = default_streams[device_index]; default_streamsi.stream = nullptr; NPU_CHECK_ERROR( @@ -639,6 +640,7 @@ void recovery_all_npu_streams(c10::DeviceIndex device_index) static void initDeviceSyncLaunchStream(c10::DeviceIndex device_index) { NPUGuard device_guard{device_index}; + LazySetDevice(device_index); for (int i = 0; i < kSyncLaunchStreamsPerPool; ++i) { auto& sync_streami = sync_launch_streams[device_index][i]; @@ -653,7 +655,7 @@ static void initDeviceSyncLaunchStream(c10::DeviceIndex device_index) NPUStream getNPUStreamFromSyncLaunchPool(c10::DeviceIndex device_index) { // in order to init num_npus - initNPUStreamsOnce(); + initNPUStreamsOnce(device_index); if (device_index == -1) { device_index = current_device(); } -- Gitee