diff --git a/test/npu/_fault_mode_cases/error_set_device.py b/test/npu/_fault_mode_cases/error_set_device.py index ddda33f34663704bdd2d1ab5647ce1892c14cceb..120994a1a6566ee70f9b4af2b6dd63eff7a8a636 100644 --- a/test/npu/_fault_mode_cases/error_set_device.py +++ b/test/npu/_fault_mode_cases/error_set_device.py @@ -11,7 +11,7 @@ def _worker(i: int) -> None: def set_device(): torch_npu.npu.set_device(0) multiprocessing.set_start_method("spawn", force=True) - jobs = [multiprocessing.Process(target=_worker, args=(i,)) for i in range(70)] + jobs = [multiprocessing.Process(target=_worker, args=(i,)) for i in range(60)] for p in jobs: p.start() @@ -20,4 +20,5 @@ def set_device(): p.join() -set_device() +if __name__ == "__main__": + set_device() diff --git a/test/npu/test_fault_mode.py b/test/npu/test_fault_mode.py index 0d52d5d11c6a55f679676dd6d50ea3e5bf601e99..f77fac8393f31184a42a4af545b40a4a71b743ff 100644 --- a/test/npu/test_fault_mode.py +++ b/test/npu/test_fault_mode.py @@ -1,6 +1,5 @@ import os import subprocess - import torch from torch.testing._internal.common_utils import TestCase, run_tests from torch.utils.checkpoint import checkpoint