diff --git a/test/test_thread.py b/test/test_thread.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd72fe521ff953093fb5435adf1fe2eb2efb327c
--- /dev/null
+++ b/test/test_thread.py
@@ -0,0 +1,142 @@
+import os
+import tempfile
+import threading
+import queue
+
+import torch
+import torch_npu  # imported for its side effects: registers the NPU device and Tensor.npu()
+
+from torch_npu.utils._path_manager import PathManager
+from torch_npu.testing.testcase import TestCase, run_tests
+
+
+class MultiThreadTest(TestCase):
+    def test_single_forward_backward(self):
+        # One forward and one backward matmul run concurrently, serialized by a lock.
+        lock = threading.Lock()
+        data = torch.randn(2, 2).npu()
+
+        def forward_op():
+            with lock:
+                t_input = torch.randn(2, 2).npu()
+                output = torch.mm(t_input, data)
+
+        def backward_op():
+            with lock:
+                grad_output = torch.randn(2, 2).npu()
+                grad_input = torch.mm(grad_output, data.t())
+
+        t1 = threading.Thread(target=forward_op)
+        t2 = threading.Thread(target=backward_op)
+        t1.start()
+        t2.start()
+        t1.join()
+        t2.join()
+
+    def test_multiple_forward_threads(self):
+        # Several threads each run an independent forward matmul on the NPU.
+        data = torch.randn(2, 2).npu()
+
+        def forward_op(thread_id):
+            t_input = torch.randn(2, 2).npu()
+            output = torch.mm(t_input, data)
+
+        num_threads = 5
+        threads = []
+        for i in range(num_threads):
+            t = threading.Thread(target=forward_op, args=(i,))
+            threads.append(t)
+            t.start()
+        for t in threads:
+            t.join()
+
+    def test_multiple_forward_backward_threads(self):
+        # Forward and backward matmul threads run interleaved, without explicit locking.
+        data = torch.randn(2, 2).npu()
+
+        def forward_op(thread_id):
+            t_input = torch.randn(2, 2).npu()
+            output = torch.mm(t_input, data)
+
+        def backward_op(thread_id):
+            grad_output = torch.randn(2, 2).npu()
+            grad_input = torch.mm(grad_output, data.t())
+
+        num_threads = 5
+        threads = []
+        for i in range(num_threads):
+            t = threading.Thread(target=forward_op, args=(i,))
+            threads.append(t)
+            t.start()
+        for i in range(num_threads):
+            t = threading.Thread(target=backward_op, args=(i,))
+            threads.append(t)
+            t.start()
+        for t in threads:
+            t.join()
+
+    def test_multiple_weight_loading_threads(self):
+        # Several threads load the same checkpoint from disk and move it to the NPU.
+        tmp_dir = tempfile.mkdtemp()
+        try:
+            weight_file = os.path.join(tmp_dir, "test_weight.pt")
+            data = torch.randn(10)
+            torch.save(data, weight_file)
+
+            def load_weights(thread_id):
+                weight_path = os.path.join(tmp_dir, "test_weight.pt")
+                loaded_data = torch.load(weight_path)
+                self.assertTrue(torch.allclose(loaded_data, data), "Incorrect data loading from hard drive")
+                data_cpu = loaded_data.cpu()
+                data_npu = data_cpu.npu()
+                self.assertEqual(data_npu.device.type, 'npu', "device type mismatch during NPU conversion")
+
+                # (1, 10) @ (10, 1) -> (1, 1)
+                result = torch.mm(data_npu.unsqueeze(0), data_npu.unsqueeze(1))
+                self.assertEqual(result.shape, (1, 1), "Matrix multiplication result shape mismatch")
+
+            num_threads = 5
+            threads = []
+            for i in range(num_threads):
+                t = threading.Thread(target=load_weights, args=(i,))
+                threads.append(t)
+                t.start()
+            for t in threads:
+                t.join()
+        finally:
+            PathManager.remove_path_safety(tmp_dir)
+
+    def test_task_queue_multithreading(self):
+        # Worker threads drain a shared task queue; a None sentinel per worker
+        # signals shutdown once all real tasks have been processed.
+        task_queue = queue.Queue()
+        data = torch.randn(2, 2).npu()
+
+        def process_task(thread_id, t_queue):
+            # get() blocks until an item is available, so no queue.Empty handling is needed.
+            while True:
+                task = t_queue.get()
+                if task is None:
+                    break
+                t_input = torch.randn(2, 2).npu()
+                output = torch.mm(t_input, data)
+                t_queue.task_done()
+
+        num_threads = 3
+        threads = []
+        for i in range(num_threads):
+            t = threading.Thread(target=process_task, args=(i, task_queue))
+            threads.append(t)
+            t.start()
+        tasks = [f"task{i}" for i in range(10)]
+        for task in tasks:
+            task_queue.put(task)
+        task_queue.join()
+        for _ in range(num_threads):
+            task_queue.put(None)
+        for t in threads:
+            t.join()
+
+
+if __name__ == '__main__':
+    run_tests()