diff --git a/torch_npu/csrc/core/npu/NpuVariables.cpp b/torch_npu/csrc/core/npu/NpuVariables.cpp
index 24a2a8da62cf6ecb23684a51be0dadf01b671b04..cbae2a499e9f7402e63b15bfad6c3fde54a86294 100644
--- a/torch_npu/csrc/core/npu/NpuVariables.cpp
+++ b/torch_npu/csrc/core/npu/NpuVariables.cpp
@@ -104,5 +104,14 @@ bool IsAclnnOnly()
     return false;
 }
 
+
+bool IsSupportAclOpLazyInit()
+{
+    static bool default_support = ((GetSocVersion() >= SocVersion::Ascend910B1) &&
+        (GetSocVersion() < SocVersion::Ascend310B1)) ||
+        (GetSocVersion() >= SocVersion::Ascend910_9391);
+    return default_support;
+}
+
 } // namespace c10_npu
 
diff --git a/torch_npu/csrc/core/npu/NpuVariables.h b/torch_npu/csrc/core/npu/NpuVariables.h
index 6a3a8cdfd7e9b8a59fcb5712ab81146dac5be875..3e0846fb13b3f00d9a1a416b6e9f4cb65c05cbdf 100644
--- a/torch_npu/csrc/core/npu/NpuVariables.h
+++ b/torch_npu/csrc/core/npu/NpuVariables.h
@@ -42,6 +42,9 @@ bool IsSupportInfNan();
 bool IsBF16Supported();
 
 bool IsAclnnOnly();
+
+bool IsSupportAclOpLazyInit();
+
 } // namespace c10_npu
 
 #endif
diff --git a/torch_npu/csrc/core/npu/register/OptionsManager.cpp b/torch_npu/csrc/core/npu/register/OptionsManager.cpp
index b5d5e99a76a2a8f3e2c657bb32acef8a554b8d2b..0da57b69a60930b8218e494a5029f96cc4c4daa5 100644
--- a/torch_npu/csrc/core/npu/register/OptionsManager.cpp
+++ b/torch_npu/csrc/core/npu/register/OptionsManager.cpp
@@ -12,6 +12,7 @@
 #include "torch_npu/csrc/core/npu/register/OptionRegister.h"
 #include "torch_npu/csrc/core/npu/register/OptionsManager.h"
 #include "torch_npu/csrc/core/npu/NPUFunctions.h"
+#include "torch_npu/csrc/core/npu/NpuVariables.h"
 #include "torch_npu/csrc/core/npu/NPUCachingAllocator.h"
 #include "torch_npu/csrc/npu/memory_snapshot.h"
 
@@ -477,19 +478,26 @@ uint32_t OptionsManager::GetP2PBufferSize()
     return buf_size;
 }
 
+static uint32_t acl_op_init_mode = []() -> uint32_t {
+    char* buf_val = std::getenv("ACL_OP_INIT_MODE");
+    int64_t default_value = c10_npu::IsSupportAclOpLazyInit() ? 1 : 0;
+    int64_t acl_op_init_mode_ = (buf_val != nullptr) ? strtol(buf_val, nullptr, 10) : default_value;
+    std::unordered_map<int64_t, std::string> aclOpInitMode = getAclOpInitMode();
+    if (aclOpInitMode.find(acl_op_init_mode_) == aclOpInitMode.end()) {
+        acl_op_init_mode_ = 0;
+        TORCH_NPU_WARN_ONCE(
+            "Get env ACL_OP_INIT_MODE not in [0, 1, 2], so reset it to the default value ", default_value, ".");
+    }
+    return static_cast<uint32_t>(acl_op_init_mode_);
+}();
+
+void OptionsManager::SetAclOpInitMode(uint32_t val)
+{
+    acl_op_init_mode = val;
+}
+
 uint32_t OptionsManager::GetAclOpInitMode()
 {
-    const static uint32_t acl_op_init_mode = []() -> uint32_t {
-        char* buf_val = std::getenv("ACL_OP_INIT_MODE");
-        // Default 0
-        int64_t acl_op_init_mode_ = (buf_val != nullptr) ? strtol(buf_val, nullptr, 10) : 0;
-        std::unordered_map<int64_t, std::string> aclOpInitMode = getAclOpInitMode();
-        if (aclOpInitMode.find(acl_op_init_mode_) == aclOpInitMode.end()) {
-            acl_op_init_mode_ = 0;
-            TORCH_NPU_WARN_ONCE("Get env ACL_OP_INIT_MODE not in [0, 1, 2], so reset it to the default value 0.");
-        }
-        return static_cast<uint32_t>(acl_op_init_mode_);
-    }();
     return acl_op_init_mode;
 }
 
diff --git a/torch_npu/csrc/core/npu/register/OptionsManager.h b/torch_npu/csrc/core/npu/register/OptionsManager.h
index 73f5dbcb81f9fc268d8ef9122407e66b976dad08..d20d0621aaf9c97e303664cbefb05e60176c51f3 100644
--- a/torch_npu/csrc/core/npu/register/OptionsManager.h
+++ b/torch_npu/csrc/core/npu/register/OptionsManager.h
@@ -126,6 +126,7 @@ public:
     static uint32_t GetHcclBufferSize();
     static uint32_t GetP2PBufferSize();
     static uint32_t GetTaskQueueEnable();
+    static void SetAclOpInitMode(uint32_t val);
     static uint32_t GetAclOpInitMode();
     static uint32_t GetStreamsPerDevice();
     static char* GetCpuAffinityConf();
diff --git a/torch_npu/csrc/framework/LazyInitAclops.cpp b/torch_npu/csrc/framework/LazyInitAclops.cpp
index 440d237fef303f45fcd0883e43298df053492e70..a3960bc9c9ac884a7d3f0f97a7bc06f9688ddfa7 100644
--- a/torch_npu/csrc/framework/LazyInitAclops.cpp
+++ b/torch_npu/csrc/framework/LazyInitAclops.cpp
@@ -174,6 +174,22 @@ void LazyInitAclopsCore()
 #endif
 }
 
+bool IsJitCompileDisable()
+{
+    static const std::string jit_compile_option_name = "jitCompile";
+    auto option_value = c10_npu::option::GetOption(jit_compile_option_name);
+    if (option_value.has_value() && (option_value.value() == "disable")) {
+        return true;
+    } else {
+        static const std::string jit_compile_init_option_name = "jitCompileInit";
+        auto init_option_value = c10_npu::option::GetOption(jit_compile_init_option_name);
+        if (init_option_value.has_value() && (init_option_value.value() == "disable")) {
+            return true;
+        }
+    }
+    return false;
+}
+
 void LazyInitAclops()
 {
     static auto acl_op_init_mode = c10_npu::option::OptionsManager::GetAclOpInitMode();
@@ -186,6 +202,9 @@ void LazyInitAclops()
     if (!encounteredAclops.exchange(true) &&
         c10_npu::NpuSysCtrl::GetInstance().GetInitFlag()) {
         RECORD_FUNCTION("LazyInitAclops", std::vector<c10::IValue>({}));
+        std::string val = IsJitCompileDisable() ? "disable" : "enable";
+        NPU_CHECK_ERROR(at_npu::native::AclSetCompileopt(aclCompileOpt::ACL_OP_JIT_COMPILE, val.c_str()));
+        ASCEND_LOGI("Set jitCompileInit option to %s", val.c_str());
         LazyInitAclopsCore();
         ASCEND_LOGI("Lazy init for aclops finished.")
     }
diff --git a/torch_npu/csrc/framework/LazyInitAclops.h b/torch_npu/csrc/framework/LazyInitAclops.h
index b842b786522309ec08b548aea5488e6f93d52cc3..fccd8858f4a6555b74b6d9664ca87547db3f987c 100644
--- a/torch_npu/csrc/framework/LazyInitAclops.h
+++ b/torch_npu/csrc/framework/LazyInitAclops.h
@@ -8,6 +8,8 @@ void InitAclops();
 void LazyInitAclops();
 void InitializeJitCompilationMode();
 
+bool IsJitCompileDisable();
+
 } // namespace aclops
 } // namespace at_npu
 
diff --git a/torch_npu/csrc/npu/Module.cpp b/torch_npu/csrc/npu/Module.cpp
index 65da2f7d4c37c7f8a86b550cc4e1150619104b89..688d0f8469619907be6f6a42635e6467e2bd25c9 100644
--- a/torch_npu/csrc/npu/Module.cpp
+++ b/torch_npu/csrc/npu/Module.cpp
@@ -34,6 +34,7 @@
 #include "torch_npu/csrc/core/npu/register/OptionRegister.h"
 #include "torch_npu/csrc/core/OverflowUtils.h"
 #include "torch_npu/csrc/framework/StorageDescHelper.h"
+#include "torch_npu/csrc/framework/LazyInitAclops.h"
 #include "torch_npu/csrc/npu/DataParallelComm.h"
 #include "torch_npu/csrc/npu/NPUPluggableAllocator.h"
 #include "torch_npu/csrc/npu/Stream.h"
@@ -946,18 +947,10 @@ PyObject *THNPModule_is_jit_compile_false_wrap(PyObject *self, PyObject *noargs)
 {
     HANDLE_TH_ERRORS
     pybind11::gil_scoped_release no_gil;
-    static const std::string jit_compile_option_name = "jitCompile";
-    auto option_value = c10_npu::option::GetOption(jit_compile_option_name);
-    if (option_value.has_value() && (option_value.value() == "disable")) {
+    if (at_npu::aclops::IsJitCompileDisable()) {
         Py_RETURN_TRUE;
     } else {
-        static const std::string jit_compile_init_option_name = "jitCompileInit";
-        auto init_option_value = c10_npu::option::GetOption(jit_compile_init_option_name);
-        if (init_option_value.has_value() && (init_option_value.value() == "disable")) {
-            Py_RETURN_TRUE;
-        } else {
-            Py_RETURN_FALSE;
-        }
+        Py_RETURN_FALSE;
     }
     END_HANDLE_TH_ERRORS
 }
@@ -1830,6 +1823,15 @@ static PyObject* THNPModule_reset_device_res_limit(PyObject* self, PyObject *args)
     END_HANDLE_TH_ERRORS
 }
 
+static PyObject* THNPModule_start_acl_op_init(PyObject* self, PyObject *noargs)
+{
+    HANDLE_TH_ERRORS
+    c10_npu::option::OptionsManager::SetAclOpInitMode(0);
+    at_npu::aclops::LazyInitAclops();
+    Py_RETURN_NONE;
+    END_HANDLE_TH_ERRORS
+}
+
 static struct PyMethodDef THNPModule_methods[] = {
     {"_npu_init", (PyCFunction)THNPModule_initExtension, METH_NOARGS, nullptr},
     {"_npu_set_run_yet_variable_to_false", (PyCFunction)THNPModule_set_run_yet_variable_to_false_wrap, METH_NOARGS, nullptr},
@@ -1898,6 +1900,7 @@ static struct PyMethodDef THNPModule_methods[] = {
     {"_npu_get_device_res_limit", (PyCFunction)THNPModule_get_device_res_limit, METH_VARARGS, nullptr},
     {"_npu_set_device_res_limit", (PyCFunction)THNPModule_set_device_res_limit, METH_VARARGS, nullptr},
     {"_npu_reset_device_res_limit", (PyCFunction)THNPModule_reset_device_res_limit, METH_O, nullptr},
+    {"_start_acl_op_init", (PyCFunction)THNPModule_start_acl_op_init, METH_NOARGS, nullptr},
     {nullptr}};
 
 TORCH_NPU_API PyMethodDef* THNPModule_get_methods()
diff --git a/torch_npu/dynamo/__init__.py b/torch_npu/dynamo/__init__.py
index a6c235708733bcc8b1fe28e285e47bd821d41a6d..575336ee02634148abc79696a557f07cbece80b6 100644
--- a/torch_npu/dynamo/__init__.py
+++ b/torch_npu/dynamo/__init__.py
@@ -74,6 +74,8 @@ class _LazyTorchair:
 
         try:
             from . import torchair
+            import torch_npu
+            torch_npu._C._start_acl_op_init()
         except Exception as e:
             # In cpython, default import loader will suppress error when
             # find module's __spec__. So here we need to record error and