diff --git a/torch_npu/_inductor/__init__.py b/torch_npu/_inductor/__init__.py index a13b579799544894583ba67db3c2374ef81d4336..d6c56d3e1e577f17104e831b5c98bb88093dc422 100644 --- a/torch_npu/_inductor/__init__.py +++ b/torch_npu/_inductor/__init__.py @@ -1,4 +1,5 @@ import os +import atexit import torch from torch._dynamo.device_interface import register_interface_for_device, get_interface_for_device @@ -6,6 +7,7 @@ from torch._inductor import lowering as inductor_lowering from torch._inductor.choices import InductorChoices from torch._inductor.codegen.common import register_backend_for_device, register_device_op_overrides from torch._inductor.runtime import autotune_cache +from torch._inductor.async_compile import shutdown_compile_workers from torch_npu.npu import device_count from torch_npu.utils._dynamo_device import NpuInterface, current_device, set_device from torch_npu.utils._inductor import NPUDeviceOpOverrides @@ -13,7 +15,7 @@ from torch_npu.utils._inductor import NPUDeviceOpOverrides from . import config as npu_config from . import codegen from .npu_fusion_attention_graph import register_fa_pass -from .config import aggresive_autotune, num_vector_core, set_compile_threads +from .config import aggresive_autotune, num_vector_core from .config import log as npulog from .decomposition import _register_npu_inductor_decompositons from .lowering import make_reduction, npu_make_fallback @@ -22,7 +24,7 @@ from .npu_device import NewNPUDeviceOpOverrides, NewNpuInterface from .runtime import _load_cached_autotuning from .utils import get_current_raw_stream -set_compile_threads() +atexit.register(shutdown_compile_workers) def _inductor_register_backend_for_device(): diff --git a/torch_npu/_inductor/config.py b/torch_npu/_inductor/config.py index f9bf23ee3378200f366c568b1e3cd6ed49a451e8..3a43ee8c2552bcb15d42dc44bd413363eef77e55 100644 --- a/torch_npu/_inductor/config.py +++ b/torch_npu/_inductor/config.py @@ -95,7 +95,7 @@ inductor_static_mode = os.environ.get('INDUCTOR_STATIC_MODE', '0').lower() in (' profile_path = "./profile_result/" -def set_compile_threads(): +def set_compile_threads_to_1(): if "TORCHINDUCTOR_COMPILE_THREADS" in os.environ: torchinductor_compile_threads = int(os.environ["TORCHINDUCTOR_COMPILE_THREADS"]) if torchinductor_compile_threads == 1: