diff --git a/test/npu/test_device.py b/test/npu/test_device.py index 156d4c64877ab5e4941a6f0b15891576d7dbcf81..ba32915f1b8859b4faee7068f86a08493d24e1b9 100644 --- a/test/npu/test_device.py +++ b/test/npu/test_device.py @@ -97,22 +97,53 @@ class TestDevice(TestCase): assert isinstance(device, torch._C.device) assert isinstance(device, torch.device) - def test_multithread_device(self): + def test_multithread_device_with_set_device(self): import threading def _worker(result): try: + torch.npu.set_device("npu:0") cur = torch_npu.npu.current_device() self.assertEqual(cur, 0) except Exception: result[0] = 1 - result = [0] - torch.npu.set_device("npu:0") + result = [0, 0] + + try: + torch.npu.set_device("npu:0") + cur = torch_npu.npu.current_device() + self.assertEqual(cur, 0) + except Exception: + result[1] = 1 + thread = threading.Thread(target=_worker, args=(result,)) + thread.start() + thread.join() + self.assertEqual(result[0], 0) + self.assertEqual(result[1], 0) + + def test_multithread_device_with_no_device(self): + import threading + + def _worker(result): + try: + cur = torch_npu.npu.current_device() + self.assertEqual(cur, 0) + except Exception: + result[0] = 1 + + result = [0, 0] + + try: + cur = torch_npu.npu.current_device() + self.assertEqual(cur, 0) + except Exception: + result[1] = 1 thread = threading.Thread(target=_worker, args=(result,)) thread.start() thread.join() self.assertEqual(result[0], 0) + self.assertEqual(result[1], 0) if __name__ == '__main__': diff --git a/torch_npu/acl.json b/torch_npu/acl.json index 8a77faac4fa153432b380a418d81361586790f59..fc0b7aa696d54dc995063905175143b951e09d91 100644 --- a/torch_npu/acl.json +++ b/torch_npu/acl.json @@ -1 +1,5 @@ -{"dump":{"dump_scene":"lite_exception"}} \ No newline at end of file +{ + "dump": { + "dump_scene": "lite_exception" + } +} \ No newline at end of file