diff --git a/tf_adapter/python/npu_bridge/__init__.py b/tf_adapter/python/npu_bridge/__init__.py index 428dd33a067b352d2f35b5107dea1ac97cfd9137..a7e296c5415c183ef5c8296519fa18a5ccf27f21 100644 --- a/tf_adapter/python/npu_bridge/__init__.py +++ b/tf_adapter/python/npu_bridge/__init__.py @@ -23,6 +23,8 @@ from npu_bridge.helper import helper from npu_bridge.estimator.npu import npu_estimator from npu_bridge.hccl import hccl_ops from npu_bridge.estimator.npu.npu_plugin import npu_close +from npu_bridge.estimator.npu import npu_plugin +npu_plugin.set_device_sat_mode(1) atexit.register(npu_close) __all__ = [_s for _s in dir() if not _s.startswith('_')] diff --git a/tf_adapter_2.x/python/npu_device/npu_device.py b/tf_adapter_2.x/python/npu_device/npu_device.py index b0036738f8299905ba5f052c6c28f5eceeed1ae1..0e0752a8e7222026f8d21b1338f8705fedba945f 100644 --- a/tf_adapter_2.x/python/npu_device/npu_device.py +++ b/tf_adapter_2.x/python/npu_device/npu_device.py @@ -168,6 +168,7 @@ def open(device_id=None): """Initiate and return a NPU device handle""" if device_id is None: device_id = int(os.getenv("ASCEND_DEVICE_ID", '0')) + set_device_sat_mode(1) with _npu_ctx_lock: if not isinstance(context.context(), _ContextWithDefaultDevice):