"""Multi-threaded smoke tests for torch_npu (test/test_thread.py).

This is the state of the file after applying both commits of the patch
series ("[PATCH 1/2] thread" and "[PATCH 2/2] add"), written out as regular,
conventionally indented Python source.

Every test runs small NPU matrix multiplications from several Python threads
at once.  Worker threads are started through the private helpers below so
that an exception (including a failed ``self.assert*`` call) raised inside a
worker is re-raised in the main thread; a plain ``threading.Thread`` would
let the worker die on its own and the test would still be reported as passed.
"""
import unittest
import os
import tempfile
import threading
import queue
import time

import numpy as np
import torch
import torch_npu

from torch_npu.utils._path_manager import PathManager
from torch_npu.testing.testcase import TestCase, run_tests


def _start_worker(failures, target, *args):
    """Start ``target(*args)`` in a new thread and return that thread.

    Args:
        failures: ``queue.Queue`` that receives any exception raised by
            ``target`` so the main thread can re-raise it later.
        target: callable executed in the worker thread.
        *args: positional arguments forwarded to ``target``.
    """

    def guarded():
        try:
            target(*args)
        except Exception as exc:
            # Thread boundary: hand the failure over to the main thread
            # instead of letting it end only this worker.
            failures.put(exc)

    worker = threading.Thread(target=guarded)
    worker.start()
    return worker


def _join_workers(workers, failures):
    """Join every worker, then re-raise the first recorded failure, if any."""
    for worker in workers:
        worker.join()
    if not failures.empty():
        raise failures.get()


def _run_concurrently(jobs):
    """Run each ``(target, args)`` pair in its own thread and wait for all.

    Threads are started in the order given and all of them are joined before
    a worker failure is re-raised in the calling thread.
    """
    failures = queue.Queue()
    workers = [_start_worker(failures, target, *args) for target, args in jobs]
    _join_workers(workers, failures)


class MultiThreadTest(TestCase):
    """Exercise NPU tensor operations from concurrently running threads."""

    def test_single_forward_backward(self):
        """One forward and one backward thread, serialized by a shared lock."""
        lock = threading.Lock()
        data = torch.randn(2, 2).npu()

        def forward_op():
            with lock:
                t_input = torch.randn(2, 2).npu()
                output = torch.mm(t_input, data)
                self.assertEqual(output.shape, (2, 2), "Forward result shape mismatch")

        def backward_op():
            with lock:
                grad_output = torch.randn(2, 2).npu()
                grad_input = torch.mm(grad_output, data.t())
                self.assertEqual(grad_input.shape, (2, 2), "Backward result shape mismatch")

        _run_concurrently([(forward_op, ()), (backward_op, ())])

    def test_multiple_forward_threads(self):
        """Five unsynchronized threads multiplying against one shared tensor."""
        data = torch.randn(2, 2).npu()

        def forward_op(thread_id):
            t_input = torch.randn(2, 2).npu()
            output = torch.mm(t_input, data)
            self.assertEqual(output.shape, (2, 2), "Forward result shape mismatch")

        num_threads = 5
        _run_concurrently([(forward_op, (i,)) for i in range(num_threads)])

    def test_multiple_forward_backward_threads(self):
        """Five forward plus five backward threads running at the same time."""
        data = torch.randn(2, 2).npu()

        def forward_op(thread_id):
            t_input = torch.randn(2, 2).npu()
            output = torch.mm(t_input, data)
            self.assertEqual(output.shape, (2, 2), "Forward result shape mismatch")

        def backward_op(thread_id):
            grad_output = torch.randn(2, 2).npu()
            grad_input = torch.mm(grad_output, data.t())
            self.assertEqual(grad_input.shape, (2, 2), "Backward result shape mismatch")

        num_threads = 5
        # All forward workers are started first, then all backward workers.
        jobs = [(forward_op, (i,)) for i in range(num_threads)]
        jobs += [(backward_op, (i,)) for i in range(num_threads)]
        _run_concurrently(jobs)

    def test_multiple_weight_loading_threads(self):
        """Five threads load the same saved tensor and move it to the NPU."""
        tmp_dir = tempfile.mkdtemp()
        try:
            weight_file = os.path.join(tmp_dir, "test_weight.pt")
            data = torch.randn(10)
            torch.save(data, weight_file)

            def load_weights(thread_id):
                weight_path = os.path.join(tmp_dir, "test_weight.pt")
                loaded_data = torch.load(weight_path)
                self.assertTrue(torch.allclose(loaded_data, data), "Incorrect data loading from hard drive")
                data_cpu = loaded_data.cpu()
                data_npu = data_cpu.npu()
                self.assertTrue(data_npu.device.type == 'npu', "device type mismatch during NPU conversion")

                # (1, 10) @ (10, 1) -> (1, 1)
                result = torch.mm(data_npu.unsqueeze(0), data_npu.unsqueeze(1))
                self.assertEqual(result.shape, (1, 1), "Matrix multiplication result shape mismatch")

            num_threads = 5
            _run_concurrently([(load_weights, (i,)) for i in range(num_threads)])
        finally:
            # Runs even when a worker failure is re-raised above.
            PathManager.remove_path_safety(tmp_dir)

    def test_task_queue_multithreading(self):
        """Three workers drain ten tasks from a shared queue."""
        task_queue = queue.Queue()
        data = torch.randn(2, 2).npu()

        def process_task(thread_id, t_queue):
            first_error = None
            while True:
                task = t_queue.get()
                try:
                    if task is None:
                        # Sentinel: no more work for this worker.
                        break
                    t_input = torch.randn(2, 2).npu()
                    output = torch.mm(t_input, data)
                    self.assertEqual(output.shape, (2, 2), "Task result shape mismatch")
                except Exception as exc:
                    # Keep draining the queue: if every worker stopped on
                    # its first error, task_queue.join() below would block
                    # forever on the tasks nobody picked up.
                    if first_error is None:
                        first_error = exc
                finally:
                    # Always balance get(), otherwise join() never returns.
                    t_queue.task_done()
            if first_error is not None:
                raise first_error

        num_threads = 3
        failures = queue.Queue()
        workers = [
            _start_worker(failures, process_task, i, task_queue)
            for i in range(num_threads)
        ]
        tasks = [f"task{i}" for i in range(10)]
        for task in tasks:
            task_queue.put(task)
        # Wait until all real tasks are processed before sending sentinels.
        task_queue.join()
        for _ in range(num_threads):
            task_queue.put(None)
        _join_workers(workers, failures)


if __name__ == '__main__':
    run_tests()