diff --git a/test/npu/_fault_mode_cases/error_set_device.py b/test/npu/_fault_mode_cases/error_set_device.py index ddda33f34663704bdd2d1ab5647ce1892c14cceb..1545ee9123ae20a5f4bb6998fe739df8415dede9 100644 --- a/test/npu/_fault_mode_cases/error_set_device.py +++ b/test/npu/_fault_mode_cases/error_set_device.py @@ -11,7 +11,7 @@ def _worker(i: int) -> None: def set_device(): torch_npu.npu.set_device(0) multiprocessing.set_start_method("spawn", force=True) - jobs = [multiprocessing.Process(target=_worker, args=(i,)) for i in range(70)] + jobs = [multiprocessing.Process(target=_worker, args=(i,)) for i in range(100)] for p in jobs: p.start() @@ -20,4 +20,5 @@ def set_device(): p.join() -set_device() +if __name__ == "__main__": + set_device() diff --git a/test/npu/test_fault_mode.py b/test/npu/test_fault_mode.py index 88bc8cca19d212986e261df9266ab0168502659c..df7650eb4e57301b96719bb0f7f7e1ad6868c5c2 100644 --- a/test/npu/test_fault_mode.py +++ b/test/npu/test_fault_mode.py @@ -1,6 +1,5 @@ import os import subprocess - import torch from torch.testing._internal.common_utils import TestCase, run_tests from torch.utils.checkpoint import checkpoint