From 0bb6d3232895556c1faff7446b9bd8b39572a109 Mon Sep 17 00:00:00 2001 From: SCh-zx <1325467101@qq.com> Date: Fri, 29 Aug 2025 08:26:44 +0000 Subject: [PATCH 1/2] add test/test_multi_threads.py. Signed-off-by: SCh-zx <1325467101@qq.com> --- test/test_multi_threads.py | 217 +++++++++++++++++++++++++++++++++++++ 1 file changed, 217 insertions(+) create mode 100644 test/test_multi_threads.py diff --git a/test/test_multi_threads.py b/test/test_multi_threads.py new file mode 100644 index 00000000000..a0ad301c602 --- /dev/null +++ b/test/test_multi_threads.py @@ -0,0 +1,217 @@ +import os +import tempfile +import threading +import queue + +import torch +from torch.testing._internal.common_distributed import MultiProcessTestCase +from torch.testing._internal.common_utils import run_tests +import torch.nn as nn +import numpy as np + +import torch_npu +import torch_npu.utils + +class TestST(MultiProcessTestCase): + def setUp(self): + super().setUp() + self._spawn_processes() + + @property + def world_size(self): + return 1 + + def tearDown(self): + try: + os.remove(self.file_name) + except OSError: + pass + + def _create_model(self): + model = nn.Sequential( + nn.Linear(128, 128), + nn.ReLU(), + nn.Linear(128, 128), + nn.ReLU(), + nn.Linear(128, 128) + ) + return model + + def test_single_forward_backward(self): + def test_func(result): + torch_npu.npu.set_device(0) + model = self._create_model().npu() + input_data = torch.randn(128, 128).npu() + target = torch.randn(128, 128).npu() + output = model(input_data) + loss = ((output - target) ** 2).sum() + loss.backward() + result[0] = 1 + + result = [0] + t = threading.Thread(target=test_func, args=(result, )) + t.start() + t.join() + self.assertEqual(result[0], 1) + + def test_multiple_forward_threads(self): + def test_func(result): + torch_npu.npu.set_device(0) + model = self._create_model().npu() + input_data = torch.randn(128, 128).npu() + output = model(input_data) + result[0] += 1 + + num_threads = 5 + threads = [] + muldevice = True + result = [0] + + for i in range(num_threads): + t = threading.Thread(target=test_func, args=(result, )) + threads.append(t) + t.start() + + for t in threads: + t.join() + self.assertEqual(result[0], 5) + + def test_forward_backward(self): + def test_func(result): + torch_npu.npu.set_device(0) + model = self._create_model().npu() + input_data = torch.randn(128, 128).npu() + target = torch.randn(128, 128).npu() + + output = model(input_data) + loss = ((output - target) ** 2).sum() + + loss.backward() + + has_grad = any(p.grad is not None for p in model.parameters()) + result[0] += 1 if has_grad else 0 + + num_threads = 5 + threads = [] + result = [0] + + for i in range(num_threads): + t = threading.Thread(target=test_func, args=(result,)) + threads.append(t) + t.start() + + for t in threads: + t.join() + + self.assertEqual(result[0], num_threads) + + + def test_multiple_weight_loading_threads(self): + def test_func(result): + torch_npu.npu.set_device(0) + input_data = torch.randn(128, 128).npu() + target = torch.randn(128, 128).npu() + tmp_dir = tempfile.mkdtemp() + weight_path = os.path.join(tmp_dir, "model.pt") + + try: + original_model = self._create_model().npu() + input_data = torch.randn(128, 128).npu() + target = torch.randn(128, 128).npu() + + output = original_model(input_data) + loss = (output - target).pow(2).sum() + torch.save(original_model.state_dict(), weight_path) + loaded_state_dict = torch.load(weight_path, map_location="cpu") + + new_model = self._create_model() + new_model.load_state_dict(loaded_state_dict) + new_model = new_model.npu() + + new_model.eval() + input_data = input_data.detach().requires_grad_(True) + output = new_model(input_data) + loss = (output - target).pow(2).sum() + loss.backward() + self.assertIsNotNone(new_model[0].weight.grad) + self.assertFalse(torch.isnan(new_model[0].weight.grad).any()) + result[0] += 1 + + finally: + if os.path.exists(weight_path): + os.remove(weight_path) + if os.path.exists(tmp_dir): + os.rmdir(tmp_dir) + + num_threads = 5 + threads = [] + result = [0] + + for i in range(num_threads): + t = threading.Thread(target=test_func, args=(result,)) + threads.append(t) + t.start() + + for t in threads: + t.join() + + self.assertEqual(result[0], num_threads) + + def test_put_task_with_multi_thread(self): + num_tasks = 10 + task_data = [(i, f"task_{i}_data") for i in range(num_tasks)] + + task_queue = queue.Queue(maxsize=num_tasks) + results = [False] * num_tasks + save_files = [] + + def worker_thread(thread_id): + torch_npu.npu.set_device(0) + + while True: + try: + task_id, data = task_queue.get(timeout=5) + + model = self._create_model().npu() + input_data = torch.randn(128, 128).npu() + target = torch.randn(128, 128).npu() + + output = model(input_data) + loss = ((output - target) ** 2).sum() + + loss.backward() + + results[task_id] = True + + temp_file_path = os.path.join(tempfile.gettempdir(), f"task_{task_id}_loss.log") + with open(temp_file_path, 'w') as f: + f.write(f"Task {task_id}: Loss = {loss.item():.6f}\n") + save_files.append(temp_file_path) + + task_queue.task_done() + + except queue.Empty: + break + except Exception as e: + task_queue.task_done() + continue + + threads = [] + for i in range(5): + t = threading.Thread(target=worker_thread, args=(i,)) + t.daemon = True + threads.append(t) + t.start() + + for task_id in task_data: + task_queue.put((task_id)) + + task_queue.join() + + for i in range(len(task_data)): + with self.subTest(task_id=i): + self.assertTrue(results[i], f"Task {i} failed") + + +if __name__ == "__main__": + run_tests() \ No newline at end of file -- Gitee From 73c2a1d242ad8983574a75189808f3c9d85b558b Mon Sep 17 00:00:00 2001 From: SCh-zx <1325467101@qq.com> Date: Fri, 29 Aug 2025 08:37:48 +0000 Subject: [PATCH 2/2] update test/test_multi_threads.py. Signed-off-by: SCh-zx <1325467101@qq.com> --- test/test_multi_threads.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/test/test_multi_threads.py b/test/test_multi_threads.py index a0ad301c602..8d7143042b9 100644 --- a/test/test_multi_threads.py +++ b/test/test_multi_threads.py @@ -12,6 +12,7 @@ import numpy as np import torch_npu import torch_npu.utils + class TestST(MultiProcessTestCase): def setUp(self): super().setUp() @@ -67,7 +68,7 @@ class TestST(MultiProcessTestCase): muldevice = True result = [0] - for i in range(num_threads): + for _ in range(num_threads): t = threading.Thread(target=test_func, args=(result, )) threads.append(t) t.start() @@ -95,7 +96,7 @@ class TestST(MultiProcessTestCase): threads = [] result = [0] - for i in range(num_threads): + for _ in range(num_threads): t = threading.Thread(target=test_func, args=(result,)) threads.append(t) t.start() @@ -147,7 +148,7 @@ class TestST(MultiProcessTestCase): threads = [] result = [0] - for i in range(num_threads): + for _ in range(num_threads): t = threading.Thread(target=test_func, args=(result,)) threads.append(t) t.start() -- Gitee